probml/dynamax

Support for kmeans initialization with vmap

ghuckins opened this issue · 3 comments

Hi there,

When I try to use vmap to vectorize a function that includes a kmeans initialization, I get the following error:

jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[11396,7])>with<BatchTrace(level=1/0)>

And here's the code that produces the error:

    import jax.numpy as jnp
    from jax import vmap
    from dynamax.hidden_markov_model import GaussianHMM

    # `get_key`, `latdim`, `obsdim`, `length`, `data1`, and `data2` are defined elsewhere
    hmm = GaussianHMM(latdim, obsdim)
    data1 = jnp.array(data1)
    data2 = jnp.array(data2)
    # leave-one-out folds: fold i drops trial i
    data1_train = jnp.stack([jnp.concatenate([data1[:i], data1[i+1:]]) for i in range(len(data1))])
    data2_train = jnp.stack([jnp.concatenate([data2[:i], data2[i+1:]]) for i in range(len(data2))])

    base_params1, props1 = hmm.initialize(key=get_key(), method="kmeans", emissions=data1[:length, :, :])
    params1, _ = hmm.fit_em(base_params1, props1, data1[:length, :, :], num_iters=100, verbose=False)
    base_params2, props2 = hmm.initialize(key=get_key(), method="kmeans", emissions=data2[:length, :, :])
    params2, _ = hmm.fit_em(base_params2, props2, data2[:length, :, :], num_iters=100, verbose=False)

    def _fit_fold(train, test, params):
        base_params, props = hmm.initialize(key=get_key(), method="kmeans", emissions=train[:length, :, :])
        fit_params, _ = hmm.fit_em(base_params, props, train[:length, :, :], num_iters=100, verbose=False)
        return (hmm.marginal_log_prob(fit_params, test) > hmm.marginal_log_prob(params, test)).astype(int)

    correct1 = jnp.sum(vmap(_fit_fold, in_axes=[0, 0, None])(data1_train, data1, params2))

The error traces back to scikit-learn's KMeans. The problem seems to be that scikit-learn operates on NumPy arrays, not JAX arrays, so it calls `__array__()` on the traced arrays that vmap passes in. Would it be possible to update `hmm.initialize` so that it can be used inside vectorized functions?
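
For what it's worth, the failure mode can be reproduced without dynamax or scikit-learn at all. The sketch below (where `uses_numpy` is a hypothetical stand-in for the internal sklearn call) shows that any NumPy conversion of a vmap-traced array raises the same error:

```python
import numpy as np
import jax
import jax.numpy as jnp

def uses_numpy(x):
    # sklearn's KMeans does the equivalent of this conversion internally
    return np.asarray(x).sum()

caught = None
try:
    jax.vmap(uses_numpy)(jnp.ones((3, 4)))
except jax.errors.TracerArrayConversionError:
    caught = "TracerArrayConversionError"
```

Under `vmap`, `x` is a `BatchTracer` rather than a concrete array, so `np.asarray` triggers its `__array__()` method, which raises `TracerArrayConversionError`.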

Thanks!

Hi @ghuckins, thanks for your interest in the library! Unfortunately, the sklearn-backed parts of the code won't naturally play nice with many of JAX's transformations.

From what I can tell, updating the "kmeans" option in `hmm.initialize` to use a JAX-compatible implementation of the k-means algorithm would involve writing, testing, and maintaining our own JAX k-means implementation, which is unfortunately probably outside the scope of this library (unless there is really strong demand for it).

Depending on your precise use case, there may be some reasonably straightforward workarounds. For instance, you could run the sklearn-backed "kmeans" initialization eagerly to generate the initial parameters for your HMMs, and then pass those parameters as input into a function that is vmapped over your data.
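
That pattern can be sketched with toy stand-ins (hypothetical names; `numpy_init` plays the role of the sklearn-backed `hmm.initialize`, and `jax_fit` plays the role of the pure-JAX `hmm.fit_em`): the NumPy-dependent initialization runs in an eager Python loop over concrete arrays, and only the pure-JAX part is vmapped:

```python
import numpy as np
import jax
import jax.numpy as jnp

def numpy_init(batch):
    # stands in for the sklearn-backed initializer: needs a concrete array
    return np.asarray(batch).mean(axis=0)

def jax_fit(init, batch):
    # stands in for the pure-JAX fitting step: safe to vmap
    return init + batch.sum(axis=0)

data = jnp.arange(24.0).reshape(2, 3, 4)                 # (num_folds, T, obs_dim)
inits = jnp.stack([numpy_init(fold) for fold in data])   # eager loop: no tracers here
result = jax.vmap(jax_fit)(inits, data)                  # vectorize only the JAX part
```

The key point is that the Python loop runs outside any JAX transformation, so sklearn only ever sees concrete arrays; the stacked initial parameters then become an ordinary batched input to `vmap`.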

I hope that gives some idea of the way forward. If you'd like to share more details about your use case, I'd be happy to try to give more precise advice.

Hey Giles, thanks for the reply! I actually did find a JAX implementation of k-means online and am working on incorporating it into my own codebase; I could share it once I'm done, if that would be helpful. Just let me know!

Hi @ghuckins. It would be great if you contributed your JAX implementation of k-means.