dfm/tinygp

Using `tinygp` with the "`kernel` as a Pytree" approach

patel-zeel opened this issue · 1 comment

Hi @dfm, as we have discussed in #79, I have the following thoughts about presenting the idea of using kernels as Pytrees.

  1. Creating a model by providing the kernel as a parameter seems like a more natural and efficient way of using tinygp. Considering that, shall we add or update the following code (Colab link) in the Tips section?

(`log_kernel` may not be the best name here, so feel free to suggest something better)

import jax
import jax.numpy as jnp

from tinygp import kernels, GaussianProcess


def build_gp(params):
    # X and y are assumed to be the training inputs and targets (as in the Colab)
    log_kernel, log_noise = params
    kernel = jax.tree_map(jnp.exp, log_kernel)
    noise = jnp.exp(log_noise)
    return GaussianProcess(kernel, X, diag=noise)


@jax.jit
def loss(params):
    gp = build_gp(params)
    return -gp.log_probability(y)


log_amp = -0.1
log_scale = 0.0
log_noise = -1.0
log_kernel = log_amp * kernels.ExpSquared(scale=log_scale)
params = (log_kernel, log_noise)
loss(params)
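As a quick sanity check (a sketch, not something from the Colab; `X` and `y` are assumed to be the training inputs and targets), the gradients come back with the same pytree structure as `params`:

grads = jax.grad(loss)(params)
# `grads` mirrors `params`: a kernel-shaped pytree for the kernel parameters
# plus a scalar for `log_noise`, so the whole tuple can be passed straight to
# any pytree-aware optimizer.
print(grads)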
  2. I think the "kernel as a Pytree" approach will have the most impact in the Getting Started section because of its use of kernel combinations. The current code can be translated to something like the following (Colab link). I could not make jaxopt work for some reason (maybe it detects the parameters as floats instead of DeviceArrays), so I used optax instead (the error trace is in the Colab).
import jax
import jax.numpy as jnp
import numpy as np

from tinygp import kernels, GaussianProcess

jax.config.update("jax_enable_x64", True)

def build_kernel():
    k1 = np.log(66.0) * kernels.ExpSquared(np.log(67.0))
    k2 = (
        np.log(2.4)
        * kernels.ExpSquared(np.log(90.0))
        * kernels.ExpSineSquared(
            scale=np.log(1.0),
            gamma=np.log(4.3),
        )
    )
    k3 = np.log(0.66) * kernels.RationalQuadratic(
        alpha=np.log(1.2), scale=np.log(0.78)
    )
    k4 = np.log(0.18) * kernels.ExpSquared(np.log(1.6))
    kernel = k1 + k2 + k3 + k4

    return kernel

def build_gp(params, X):
    # We want most of our parameters to be positive so we take the `exp` here
    # Note that we're using `jnp` instead of `np`
    kernel, noise, mean = params
    kernel = jax.tree_map(jnp.exp, kernel)
    return GaussianProcess(kernel, X, diag=jnp.exp(noise), mean=mean)


def neg_log_likelihood(params, X, y):
    gp = build_gp(params, X)
    return -gp.log_probability(y)

kernel = build_kernel()
log_noise = np.log(0.19)
mean = np.float64(340.0)

params_init = (kernel, log_noise, mean)

# `jax` can be used to differentiate functions, and also note that we're calling
# `jax.jit` for the best performance.
obj = jax.jit(jax.value_and_grad(neg_log_likelihood))

print(f"Initial negative log likelihood: {obj(params_init, t, y)[0]}")
print(
    f"Gradient of the negative log likelihood, wrt the parameters:\n{obj(params_init, t, y)[1]}"
)
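And here is roughly what the optax part looks like (a minimal sketch; the `adam` learning rate and number of steps are placeholders, not the Colab's exact settings):

import optax

opt = optax.adam(learning_rate=1e-2)
opt_state = opt.init(params_init)

params = params_init
for _ in range(500):
    # `obj` returns (value, grad); the gradient has the same pytree structure
    # as `params`, so optax can update the kernel, noise, and mean together.
    loss_val, grads = obj(params, t, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

print(f"Final negative log likelihood: {obj(params, t, y)[0]}")

optax handles the `(kernel, log_noise, mean)` tuple directly because every leaf of the kernel is just an array-like parameter.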
  3. In the Custom Kernels section, you have mentioned the following:

Besides describing this interface, we also show how tinygp can support arbitrary [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) as input.

I did not find anything related to the above line in that section. Was this line written to show something like what we are discussing now? In that case, I can modify the current code for the spectral mixture kernel to showcase the new approach (a rough sketch of what that could look like is below).
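For what it's worth, here is a rough sketch of that idea, using the same dataclass helpers as in the Edit below; the two-component `weight`/`scale`/`freq` parameterization is made up for illustration and is not the tutorial's actual code:

import jax.numpy as jnp

from tinygp import kernels
from tinygp.helpers import dataclass, field, JAXArray


@dataclass
class SpectralMixture(kernels.Kernel):
    # One weight, length scale, and frequency per mixture component
    weight: JAXArray = field(default_factory=lambda: jnp.ones(2))
    scale: JAXArray = field(default_factory=lambda: jnp.ones(2))
    freq: JAXArray = field(default_factory=lambda: jnp.ones(2))

    def evaluate(self, X1, X2):
        tau = X1 - X2
        return jnp.sum(
            self.weight
            * jnp.exp(-2 * jnp.pi**2 * tau**2 / self.scale**2)
            * jnp.cos(2 * jnp.pi * self.freq * tau)
        )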

Please let me know your thoughts on these proposals.

P.S.: Feel free to drop your quick suggestions directly on the Colab as comments!

Edit:
I think that, to make a new kernel work with the above approach, it needs to be defined something like this, right?

import jax.numpy as jnp

from tinygp import kernels
from tinygp.helpers import dataclass, field, JAXArray


@dataclass
class Linear(kernels.Kernel):
    scale: JAXArray = field(default_factory=lambda: jnp.ones(()))
    sigma: JAXArray = field(default_factory=lambda: jnp.zeros(()))

    def evaluate(self, X1, X2):
        return (X1 / self.scale) @ (X2 / self.scale) + jnp.square(self.sigma)
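As a quick check with hypothetical values, the same `tree_map`-exp trick from above would then apply to it directly:

import jax

# Construct the kernel with log-parameters and exp every leaf afterwards
log_linear = Linear(scale=jnp.log(2.0), sigma=jnp.log(0.1))
linear = jax.tree_map(jnp.exp, log_linear)
print(linear.scale, linear.sigma)  # approximately 2.0 and 0.1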
dfm commented

Thanks for this @patel-zeel!

Now that you've written up the details, I'm a little more hesitant to include this in the docs, specifically because of this `log_kernel` point. I think that writing something like `np.log(66.0) * kernels.ExpSquared(np.log(67.0))` and then `tree_map`-exp-ing it is bad practice and a little misleading (the amplitude is 66.0, not its log, so this is very confusing notation!). I see why you're doing it, and I think it's the only way given the design of tinygp, but I don't think we should encourage it. For example, some kernel parameters aren't required to be everywhere positive, so fitting in the log doesn't make sense for those, and this workflow doesn't offer any way to deal with that use case. It is interesting that it works for your use case, and it would be nice to document it somewhere, but I'm definitely not keen to incorporate it into the tutorials in its current form.

  3. In the Custom Kernels section, you have mentioned the following: ...

Good catch! No, this refers to the info that has been moved to the Derivative Observations & Pytree Data tutorial. The offending sentence should just be removed, if you want to open a PR to do that!