inVAE is a conditionally invariant variational autoencoder that identifies both spurious features (distractors) and invariant features. It leverages domain variability to learn conditionally invariant representations. We show that inVAE captures biological variation in single-cell datasets obtained from diverse conditions and labs. inVAE incorporates biological covariates and mechanisms, such as disease states, to learn an invariant data representation. This significantly improves cell classification accuracy.
- PyPI only
pip install invae
- Development Version (latest version on GitHub)
git clone https://github.com/theislab/inVAE.git
cd inVAE
pip install .
Integration of Human Lung Cell Atlas using both healthy and disease samples
- Load the data:
import scanpy as sc
adata = sc.read('path/to/data')
- Optional - Split the data into train, validation, and test sets (also needed in the supervised case, where a classifier is trained as well)
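One way to produce the three subsets is a seeded random split over cell indices. The sketch below is only an illustration: `split_indices` and its default fractions are hypothetical and not part of inVAE, and the commented lines assume `adata` can be subset by lists of integer row indices, as AnnData objects generally can.

```python
import random


def split_indices(n_obs, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle row indices 0..n_obs-1 and cut them into train/val/test lists."""
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac >= 1:
        raise ValueError("fractions must be non-negative and sum to less than 1")
    idx = list(range(n_obs))
    random.Random(seed).shuffle(idx)  # seeded, so the split is reproducible
    n_val = int(n_obs * val_frac)
    n_test = int(n_obs * test_frac)
    val = sorted(idx[:n_val])
    test = sorted(idx[n_val:n_val + n_test])
    train = sorted(idx[n_val + n_test:])
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000, val_frac=0.1, test_frac=0.1)

# With an AnnData object, the index lists can then select rows (assumption):
# adata_train = adata[train_idx].copy()
# adata_val = adata[val_idx].copy()
# adata_test = adata[test_idx].copy()
```

A purely random split ignores batch and cell-type composition; if some cell types are rare, splitting within each group may be the safer choice.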
- Initialize the model, either Factorized or Non-Factorized:
from inVAE import FinVAE, NFinVAE

inv_covar_keys = {
    'cont': [],
    'cat': ['cell_type', 'disease']  # set to the keys in the adata
}

spur_covar_keys = {
    'cont': [],
    'cat': ['batch']  # set to the keys in the adata
}

model = FinVAE(
    adata=adata_train,
    layer='counts',  # the layer where the raw counts are stored in adata (default None uses adata.X)
    inv_covar_keys=inv_covar_keys,
    spur_covar_keys=spur_covar_keys,
    latent_dim_inv=20,
    latent_dim_spur=5,
    device='cpu',
    decoder_dist='nb'
)
Set inject_covar_in_latent=True if you wish to add the spurious conditions directly to the latent space (instead of learning the spurious latents). This gives the version most compatible with scVI.
For the non-factorized model, use:
model = NFinVAE(
    adata=adata_train,
    layer='counts',  # the layer where the raw counts are stored in adata (default None uses adata.X)
    inv_covar_keys=inv_covar_keys,
    spur_covar_keys=spur_covar_keys,
    latent_dim_inv=20,
    latent_dim_spur=5,
    device='cpu',
    decoder_dist='nb'
)
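Both constructors receive covariate keys that are expected to exist in the data, so a small pre-check can catch a misspelled key before the model is built. The helper below is a hypothetical sketch, not part of inVAE; it assumes the keys refer to columns of `adata_train.obs`, and the column list in the example is made up.

```python
def missing_covar_keys(obs_columns, *covar_key_dicts):
    """Return covariate keys that are absent from the given column names."""
    available = set(obs_columns)
    missing = []
    for keys in covar_key_dicts:
        for kind in ("cont", "cat"):
            for key in keys.get(kind, []):
                if key not in available and key not in missing:
                    missing.append(key)
    return missing


inv_covar_keys = {"cont": [], "cat": ["cell_type", "disease"]}
spur_covar_keys = {"cont": [], "cat": ["batch"]}

# Stand-in for list(adata_train.obs.columns); these names are illustrative only.
obs_columns = ["cell_type", "batch", "n_genes"]

print(missing_covar_keys(obs_columns, inv_covar_keys, spur_covar_keys))
# prints ['disease']
```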
- Train the generative model:
model.train(n_epochs=500, lr_train=0.001, weight_decay=0.0001)
- Get the latent representation:
latent = model.get_latent_representation(adata)
- Optional - Train the classifier (for cell types):
model.train_classifier(
    adata_val,
    batch_key='batch',
    label_key='cell_type',
)
- Optional - Predict cell types:
pred_test = model.predict(adata_test, dataset_type='test')
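To quantify how well the predictions match the annotations, overall and per-cell-type accuracy can be computed from the two label sequences. This is a minimal sketch: `accuracy_report` is a hypothetical helper, the toy labels are invented, and it assumes `pred_test` holds one predicted label per cell in the same order as `adata_test.obs['cell_type']`, which is not confirmed here.

```python
from collections import Counter


def accuracy_report(true_labels, pred_labels):
    """Return (overall accuracy, {label: accuracy within that true label})."""
    if len(true_labels) != len(pred_labels):
        raise ValueError("label sequences must have the same length")
    if not true_labels:
        raise ValueError("label sequences must not be empty")
    total = Counter(true_labels)
    correct = Counter(t for t, p in zip(true_labels, pred_labels) if t == p)
    overall = sum(correct.values()) / len(true_labels)
    per_class = {label: correct[label] / n for label, n in total.items()}
    return overall, per_class


# Toy labels; in practice these might be list(adata_test.obs['cell_type'])
# and list(pred_test), if the prediction output has that shape (assumption).
true_labels = ["AT1", "AT2", "AT2", "Basal", "Basal", "Basal"]
pred_labels = ["AT1", "AT2", "AT1", "Basal", "Basal", "AT2"]

overall, per_class = accuracy_report(true_labels, pred_labels)
print(f"overall accuracy: {overall:.3f}")
for label, acc in sorted(per_class.items()):
    print(f"{label}: {acc:.3f}")
```

Per-class accuracy is reported alongside the overall figure because a few abundant cell types can otherwise dominate the score.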
- Optional - Saving and loading model:
model.save('./checkpoints/path.pt')
model.load('./checkpoints/path.pt')
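Whether `model.save` creates missing parent directories is not stated, so it may be safer to create the checkpoint folder first. The helper below is a hypothetical convenience, not part of inVAE; it is demonstrated in a temporary directory so that running it leaves nothing behind.

```python
import tempfile
from pathlib import Path


def prepare_checkpoint_path(path):
    """Create the parent directory of a checkpoint path and return the path as a string."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    return str(path)


# Demonstration inside a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    checkpoint = prepare_checkpoint_path(Path(tmp) / "checkpoints" / "path.pt")
    checkpoint_dir_exists = Path(checkpoint).parent.is_dir()

# In the workflow above this would become:
# checkpoint = prepare_checkpoint_path('./checkpoints/path.pt')
# model.save(checkpoint)
```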
Dependencies
- scanpy==1.9.3
- torch==2.0.1
- tensorboard==2.13.0
- anndata==0.8.0
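The pinned versions listed above can be compared against the current environment with the standard library. This is a sketch rather than an official check: `check_pins` and `installed_version` are hypothetical helper names, and the comparison is an exact string match, so a build such as a CUDA-tagged torch version would be reported as a mismatch.

```python
from importlib import metadata

# Versions copied from the list above.
PINS = {
    "scanpy": "1.9.3",
    "torch": "2.0.1",
    "tensorboard": "2.13.0",
    "anndata": "0.8.0",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins, lookup=installed_version):
    """Map each package whose installed version differs from its pin to (installed, pinned)."""
    mismatches = {}
    for name, pinned in pins.items():
        found = lookup(name)
        if found != pinned:
            mismatches[name] = (found, pinned)
    return mismatches


for name, (found, pinned) in check_pins(PINS).items():
    print(f"{name}: installed {found or 'not installed'}, listed {pinned}")
```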