scverse/scvi-tools

How to modify scANVI to handle continuous cell labels?

patricks-lab opened this issue · 1 comment

Is your feature request related to a problem? Please describe.

Not a problem, mainly clarification questions: I am wondering how to modify scANVI to handle continuous prediction tasks.

Currently, scANVI assumes that the cell label is drawn from a categorical distribution (https://docs.scvi-tools.org/en/stable/user_guide/models/scanvi.html#generative-process), which is implemented in the code via a cross-entropy loss between the predicted and actual cell labels:

```python
import torch.nn.functional as F

# logits: classifier outputs (n_cells x n_labels); y: integer cell-type labels
ce_loss = F.cross_entropy(
    logits,
    y.view(-1).long(),
)
```

Describe the solution you'd like

For continuous prediction tasks, I thought of replacing the above cross-entropy loss with something like an MSELoss and swapping the scvi.module.Classifier for a scvi.nn.Decoder layer with one output neuron per continuous prediction task. However, I am not sure what other changes (e.g., to the VAE graphical model and/or the code) are needed to accommodate a continuous (e.g., normal) distribution on the cell type label.
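Swapping the cross-entropy for an MSE loss amounts to replacing the categorical likelihood on the label with a Gaussian likelihood of fixed unit variance. The stdlib-only sketch below is an illustration, not scvi-tools code (`mse` and `gaussian_nll` are hypothetical helpers); it checks that with the variance fixed at 1, the Gaussian negative log-likelihood equals half the MSE plus a constant. A head that also predicts a variance, the way a Gaussian decoder outputs a mean and a variance, would correspond to the full Gaussian NLL with a learned variance, which PyTorch provides as `torch.nn.GaussianNLLLoss`.

```python
import math

def mse(pred, target):
    """Mean squared error over lists of predictions and targets."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def gaussian_nll(pred, target, var=1.0):
    """Mean negative log-likelihood of each target under N(pred, var)."""
    const = 0.5 * math.log(2 * math.pi * var)
    return sum(const + (p - t) ** 2 / (2 * var) for p, t in zip(pred, target)) / len(pred)

# Illustrative values for three cells with one continuous label each
pred = [0.2, 1.5, -0.3]
target = [0.0, 1.0, 0.5]

# With var fixed at 1 the Gaussian NLL is 0.5 * MSE plus a constant, so both
# losses have the same minimizer and gradients that differ only by a factor of 2.
const = 0.5 * math.log(2 * math.pi)
print(gaussian_nll(pred, target), 0.5 * mse(pred, target) + const)
```

So an MSE head implicitly places a normal distribution with unit variance on the continuous label.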

For example, I'd imagine that the derivation of the ELBO for the unobserved-label case would change with a continuous distribution (https://docs.scvi-tools.org/en/stable/user_guide/models/scanvi.html#training-details), but I'm not quite sure how to flesh out the details in that case.
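A sketch of how the unobserved-label bound might change, assuming scANVI follows the semi-supervised VAE bound of Kingma et al. (2014), where $\mathcal{L}(x, c)$ denotes the bound for a cell with observed label $c$ and $z_1 \sim q_\phi(z_1 \mid x)$. With a categorical label, the unknown label is summed out and the entropy of the classifier is added:

$$\mathcal{U}(x) = \sum_{c} q_\phi(c \mid z_1)\, \mathcal{L}(x, c) + \mathcal{H}\big[q_\phi(c \mid z_1)\big]$$

If the label instead had a Gaussian posterior $q_\phi(c \mid z_1) = \mathcal{N}\big(\mu_\phi(z_1), \sigma^2_\phi(z_1)\big)$ (an assumption of this sketch, for a scalar label), the sum would become an expectation with no closed form, estimated here with $S$ reparameterized samples, while the entropy becomes the closed-form Gaussian entropy:

$$\mathcal{U}(x) = \mathbb{E}_{q_\phi(c \mid z_1)}\big[\mathcal{L}(x, c)\big] + \tfrac{1}{2}\log\big(2\pi e\, \sigma^2_\phi(z_1)\big) \approx \frac{1}{S}\sum_{s=1}^{S} \mathcal{L}\big(x,\ \mu_\phi(z_1) + \sigma_\phi(z_1)\,\epsilon_s\big) + \tfrac{1}{2}\log\big(2\pi e\, \sigma^2_\phi(z_1)\big), \qquad \epsilon_s \sim \mathcal{N}(0, 1)$$

Under the same assumption, the uniform categorical prior $p(c)$ inside $\mathcal{L}(x, c)$ would need a continuous replacement such as $\mathcal{N}(0, 1)$, and the conditional prior $p(z_1 \mid z_2, c)$ would take a real-valued $c$ as input instead of a one-hot vector.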

Alternatively, I also thought of simply attaching a scvi.nn.Decoder head on top of scVI with an MSE loss to predict the continuous variable, and then backpropagating through it to adjust scVI's latent space (hopefully making a continuous trend emerge in the latent space). This would probably be easier to implement, but I'm not sure how well such a model would perform, especially since it is much less nuanced than scANVI's graphical model.
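To see why backpropagating an MSE head could shape the latent space, here is a minimal sketch assuming a hypothetical linear head `y_hat = w·z + b` on one cell's latent code. This is a simplification: a scvi.nn.Decoder head would be a nonlinear MLP, and in real training the gradient would update the encoder weights rather than `z` directly. The gradient of the squared error with respect to `z` is always a multiple of `w`, so each step pushes the cell along the direction the head reads out, which is how a continuous trend could emerge along that axis. `regression_weight` is a hypothetical hyperparameter that would trade this term off against the scVI ELBO; scANVI's semi-supervised training weights its classification term in a similar way (via a `classification_ratio` setting).

```python
def head_and_grad(z, w, b, y, regression_weight=1.0):
    """Hypothetical linear regression head on a latent vector z.

    Returns the weighted squared-error loss and its gradient w.r.t. z.
    """
    y_hat = sum(zi * wi for zi, wi in zip(z, w)) + b
    err = y_hat - y
    loss = regression_weight * err ** 2
    # d/dz of weight * (w.z + b - y)^2 is 2 * weight * err * w
    grad_z = [2 * regression_weight * err * wi for wi in w]
    return loss, grad_z

# Illustrative values: a 2-D latent code for one cell and its observed covariate
z = [0.5, -1.0]
w, b = [1.0, 2.0], 0.0
y = 1.0

loss0, g = head_and_grad(z, w, b, y)
# One gradient step on z (standing in for the update the encoder would receive)
z_new = [zi - 0.05 * gi for zi, gi in zip(z, g)]
loss1, _ = head_and_grad(z_new, w, b, y)
print(loss0, loss1)  # the regression loss decreases after the step
```

The step `z_new - z` is parallel to `w`, so the MSE term alone only organizes the latent space along the direction(s) the head uses; the ELBO keeps the rest of the latent structure tied to the expression data.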

Thanks a lot in advance!
Patrick

The structure of scANVI relies on categorical covariates, specifically in the z1-z2 loss. I would start with scVI and add a continuous predictor (a regression head) trained with something like an MSE loss. If you only have the covariate for some of your cells, you might want to look at scANVI, as it demonstrates how to handle partially labeled data.
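For the partially labeled case, the simplest continuous analog of how scANVI treats unlabeled cells is to compute the regression loss only on cells whose value is observed, while every cell still contributes to the scVI ELBO. The sketch below assumes missing values are stored as NaN; that is a convention chosen for this illustration, and `masked_mse` is a hypothetical helper. scANVI itself uses an unlabeled category and marginalizes over the label for unlabeled cells.

```python
import math

def masked_mse(y_pred, y_true):
    """MSE over cells whose continuous covariate is observed.

    Unobserved values are encoded as NaN (an assumption of this sketch);
    those cells add nothing to the regression term and are trained only
    through the scVI ELBO.
    """
    pairs = [(p, t) for p, t in zip(y_pred, y_true) if not math.isnan(t)]
    if not pairs:
        return 0.0  # a minibatch with no observed values adds no regression loss
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)

# Illustrative minibatch of four cells, two of which have no measured covariate
y_pred = [0.1, 0.8, -0.4, 1.2]
y_true = [0.0, float("nan"), -0.5, float("nan")]
print(masked_mse(y_pred, y_true))
```

In a training loop, the total loss would then be the ELBO summed over all cells plus a weighted `masked_mse` term.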