How to modify scANVI to handle continuous cell labels?
patricks-lab opened this issue · 1 comments
Is your feature request related to a problem? Please describe.
Not a problem - mainly clarification questions. I am wondering how to modify scANVI to handle continuous prediction tasks.
Currently, scANVI assumes that the cell label is drawn from a categorical distribution (https://docs.scvi-tools.org/en/stable/user_guide/models/scanvi.html#generative-process), which is implemented in the code via a cross-entropy loss (scvi-tools/scvi/module/_scanvae.py, lines 271 to 274 in 95f2e1d).
Describe the solution you'd like
For continuous prediction tasks, I thought of changing the above cross-entropy loss to something like an MSELoss and swapping the scvi.module.Classifier with a scvi.nn.Decoder layer with one output neuron per continuous prediction task. However, I am not sure what other changes (e.g. in the VAE graphical model and/or code) need to be made to accommodate a continuous (e.g. normal) distribution on the cell label.
For example, I'd imagine that the derivation of the ELBO for the unobserved cell label case (https://docs.scvi-tools.org/en/stable/user_guide/models/scanvi.html#training-details) would change with a continuous distribution, but I'm not quite sure how to flesh out the details in such a case.
Alternatively, I also thought of simply attaching a scvi.nn.Decoder head on top of scVI with an MSE loss to predict a continuous variable, and then backpropagating to adjust scVI's latent space (and hopefully see the latent space develop a continuous trend), which would probably be easier to implement. But I'm not sure how well such a model would perform, especially as it's much less nuanced than scANVI's graphical model.
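A minimal sketch of this second idea in plain PyTorch, to make the loss structure concrete. It does not use the scvi-tools API: `LatentRegressor`, `total_loss`, and the `weight` parameter are hypothetical names, the latent tensor `z` is random stand-in data rather than an scVI encoder output, and `elbo_loss` is a placeholder for whatever loss the VAE already computes.

```python
import torch
from torch import nn


class LatentRegressor(nn.Module):
    """Hypothetical regression head: maps a latent vector z to continuous targets."""

    def __init__(self, n_latent: int, n_targets: int = 1, n_hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_latent, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_targets),  # one output neuron per continuous task
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


def total_loss(elbo_loss, z, y, head, weight=1.0):
    """Add a weighted MSE regression term to an existing VAE loss."""
    mse = nn.functional.mse_loss(head(z), y)
    return elbo_loss + weight * mse


torch.manual_seed(0)
z = torch.randn(8, 10, requires_grad=True)  # stand-in for latent samples of 8 cells
y = torch.randn(8, 1)                       # stand-in continuous labels
elbo_loss = torch.tensor(0.0)               # placeholder for the VAE's own loss
head = LatentRegressor(n_latent=10)

loss = total_loss(elbo_loss, z, y, head, weight=5.0)
loss.backward()  # gradients reach z, so the regression term can shape the latent space
```

Because the MSE term is differentiated with respect to `z`, training on the summed loss would push the encoder toward a latent space from which the continuous variable is predictable. How large `weight` should be relative to the reconstruction term is an open tuning question in this sketch.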
Thanks a lot in advance!
Patrick
The structure of scANVI relies on categorical covariates, specifically in the z1-z2 loss. I would start with scVI and add a continuous prediction head using something like an MSE loss. You might want to look at scANVI if you only have the covariate for some of your cells, as it demonstrates how to handle partially labeled data.
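For the partially labeled case with a continuous covariate, one simple option is to compute the MSE only over cells whose value is observed. The sketch below assumes missing values are encoded as NaN; that convention and the `masked_mse` helper are illustrative choices, not taken from the scANVI code.

```python
import torch


def masked_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE over observed targets only; NaN in `target` marks an unlabeled cell."""
    observed = ~torch.isnan(target)
    if not observed.any():
        # No labeled cells in this minibatch: contribute zero, keep the graph intact.
        return pred.sum() * 0.0
    diff = pred[observed] - target[observed]
    return (diff ** 2).mean()


pred = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
target = torch.tensor([1.5, float("nan"), 2.0, float("nan")])

loss = masked_mse(pred, target)  # averages over cells 0 and 2 only
loss.backward()                  # unlabeled cells receive zero gradient
```

With this loss, unlabeled cells still contribute to the usual scVI objective but not to the regression term.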