Explore `jax.numpy.linalg.svd` as replacement for `sklearn.decomposition.PCA`
Closed this issue · 0 comments
quattro commented
We currently use `sklearn.decomposition.PCA` as an optional initialization of the mean parameters for the latent factor space. This approach works well in practice, but depending on the large `sklearn` package for a single function may be overkill. JAX provides an SVD implementation, `jax.numpy.linalg.svd`, that we could explore as a replacement, which would drop the `sklearn` dependency.
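A minimal sketch of how the PCA initialization might be reproduced with `jax.numpy.linalg.svd`. The helper name `pca_init` and its return values are hypothetical, not the project's existing API. It assumes a dense `(n_samples, n_features)` input, centers the columns (as `sklearn.decomposition.PCA` does by default) and does not whiten, so the returned scores correspond to `U * S` from the thin SVD of the centered data. SVD signs are arbitrary, so the sketch applies one deterministic convention (largest-magnitude loading per component made positive); individual components may still differ in sign from sklearn's output.

```python
import jax
import jax.numpy as jnp


def pca_init(X, n_components):
    """Hypothetical helper: PCA via SVD, returning (scores, components, mean).

    scores:     (n_samples, n_components), usable as initial latent factor means
    components: (n_components, n_features), orthonormal rows
    mean:       (n_features,), column means removed before the SVD
    """
    # Center each feature, matching sklearn's default PCA preprocessing.
    mean = jnp.mean(X, axis=0)
    Xc = X - mean

    # Thin SVD: Xc = U @ diag(S) @ Vt, singular values sorted in descending order.
    U, S, Vt = jnp.linalg.svd(Xc, full_matrices=False)

    # Keep the leading components; scores = U * S equals Xc @ components.T.
    components = Vt[:n_components]
    scores = U[:, :n_components] * S[:n_components]

    # Fix the sign ambiguity: make each component's largest-|value| entry positive.
    max_idx = jnp.argmax(jnp.abs(components), axis=1)
    signs = jnp.sign(components[jnp.arange(n_components), max_idx])
    components = components * signs[:, None]
    scores = scores * signs
    return scores, components, mean


if __name__ == "__main__":
    X = jax.random.normal(jax.random.PRNGKey(0), (50, 8))
    scores, components, mean = pca_init(X, 3)
    print(scores.shape, components.shape)
```

Because `n_components` controls slicing shapes, it must be a static argument if the helper is compiled, e.g. `jax.jit(pca_init, static_argnums=1)`.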