Description
I recently updated my project to the new version of tinygp, and the tutorial on fitting the mean no longer works :(
Steps to reproduce the bug:
uv init test-gp
cd test-gp/
uv add jax tinygp jupyter matplotlib optax
curl -O "https://tinygp.readthedocs.io/en/latest/_sources/tutorials/means.ipynb"
uv run jupyter nbconvert --to html --execute means.ipynb
which returns:
TypeError: Error interpreting argument to <function GaussianProcess._get_alpha at 0x71c4a264c360> as an abstract array. The problematic value is of type <class 'equinox._module._flatten._Missing'> and was passed to the function at path self.mean_function.value. This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
The error is raised when calling loss(params).
I suspect this issue is related to #200
Love tinygp, thanks for your great work.
Update: the code in the section "An alternative workflow" works, so I suspect the problem occurs when building the GP with mean=partial(mean_function, params).
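For context, a small JAX-only sketch may help narrow this down. It uses a hypothetical mean_function and params (not the tutorial's code) and assumes only that jax is installed. It shows that JAX treats a functools.partial as one opaque leaf, so passing it straight into a jax.jit function fails with the same kind of TypeError as above, whereas jax.tree_util.Partial exposes the bound arrays as ordinary pytree leaves. This is background on how partials interact with jit, not a confirmed diagnosis of the tinygp regression, which involves how GaussianProcess stores the mean (the path self.mean_function.value).

```python
import functools

import jax
import jax.numpy as jnp


def mean_function(params, X):
    # Hypothetical mean model (not the tutorial's): offset + slope * X.
    return params["offset"] + params["slope"] * X


params = {"offset": jnp.array(1.0), "slope": jnp.array(0.5)}
x = jnp.array(2.0)

# functools.partial is not registered as a pytree node, so JAX treats the
# whole object (params included) as a single opaque, non-array leaf.
plain = functools.partial(mean_function, params)
plain_leaves = jax.tree_util.tree_leaves(plain)

# jax.tree_util.Partial registers its bound arguments as pytree children,
# so the arrays inside params become ordinary leaves.
tree_aware = jax.tree_util.Partial(mean_function, params)
tree_aware_leaves = jax.tree_util.tree_leaves(tree_aware)


@jax.jit
def evaluate(mean, X):
    # `mean` is flattened and traced as a pytree argument.
    return mean(X)


result = evaluate(tree_aware, x)  # computes 1.0 + 0.5 * 2.0

try:
    evaluate(plain, x)
    plain_jits = True
except TypeError:
    # jax.jit cannot interpret the opaque partial as an abstract array.
    plain_jits = False

print(len(plain_leaves), len(tree_aware_leaves), float(result), plain_jits)
```

Whether wrapping the tutorial's params in jax.tree_util.Partial avoids the error inside tinygp is untested here.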