mancusolab/susiepca

Explore `jax.numpy.linalg.svd` as replacement for `sklearn.decomposition.PCA`

This issue is closed.

We currently use `sklearn.decomposition.PCA` as an optional way to initialize the mean parameters of the latent factor space. This approach works well in practice, but depending on the fairly large scikit-learn package for a single function may be overkill. JAX provides an SVD implementation, `jax.numpy.linalg.svd`, that we could explore as a replacement, which would remove the dependency on sklearn.
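Since PCA amounts to an SVD of the column-centered data matrix, the replacement could look roughly like the sketch below. It assumes the initialization only needs the projections (scores) of an `n x p` data matrix `X` onto its top `z_dim` principal axes. `pca_init` and `z_dim` are hypothetical names used for illustration, not part of susiepca's API.

```python
import jax.numpy as jnp


def pca_init(X, z_dim):
    """Hypothetical helper: PCA via SVD, returning (scores, components)."""
    # Center each feature (column), as sklearn's PCA does before its SVD.
    X_centered = X - jnp.mean(X, axis=0)
    # Thin SVD: X_centered = U @ diag(S) @ Vt, singular values descending.
    U, S, Vt = jnp.linalg.svd(X_centered, full_matrices=False)
    # The top z_dim right singular vectors are the principal axes (rows).
    components = Vt[:z_dim]
    # Singular vectors are only defined up to sign; pick one convention by
    # making the largest-magnitude loading of each component positive.
    max_idx = jnp.argmax(jnp.abs(components), axis=1)
    signs = jnp.sign(components[jnp.arange(z_dim), max_idx])
    components = components * signs[:, None]
    # Scores: projection of the centered data onto the principal axes.
    scores = X_centered @ components.T
    return scores, components
```

Because sklearn applies its own sign convention to the singular vectors, the scores from this sketch should agree with `PCA(n_components=z_dim).fit_transform(X)` only up to a per-component sign flip, which should not matter for an initialization.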