gao-lab/Cell_BLAST

decoder

bschilder opened this issue · 4 comments

Hi,

Awesome tool! One thing I was having trouble figuring out was how to get reconstructed per-gene expression values from a trained model. For other NN models I'd typically do this with predictions = model.predict(some_data), but I can't seem to find the equivalent function for a trained Cell BLAST model. My goal is to use the model to denoise and batch-correct the data and return it in high-dimensional space.

Many thanks,
Brian

Hi, thanks for your interest in Cell BLAST!

It is definitely possible to run data through the decoder and obtain batch-corrected high-dimensional data. In fact, we already have an implementation here.

One potential issue is which batch vector to feed to the decoder in this process (see the script linked above). To get batch-corrected high-dimensional data, every cell must receive the same batch vector, regardless of which batch it actually came from. You can do this by giving all cells the batch vector of any single batch, an all-zero batch vector, an all-one batch vector, etc. Some preliminary experiments suggested that the all-zero batch vector works best. However, since data imputation was not the focus of the project, we did not include this in the final API.
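The options above can be sketched with NumPy. This assumes the decoder takes a (cells × batches) batch matrix with one column per batch, as in the all-zero matrix built in the script below. `make_shared_batch_matrix` is a hypothetical helper for illustration and is not part of Cell BLAST. Whatever the strategy, every row is identical, which is what removes the batch signal from the decoded output.

```python
import numpy as np


def make_shared_batch_matrix(n_cells, n_batches, strategy="zeros", batch_index=0):
    """Build an (n_cells, n_batches) batch matrix whose rows are all identical.

    Hypothetical helper, not part of the Cell BLAST API.
    strategy:
      "zeros"     -> all-zero rows (reported above to work best in preliminary tests)
      "ones"      -> all-one rows
      "one_batch" -> one-hot encoding of batch `batch_index`, repeated for every cell
    """
    if strategy == "zeros":
        return np.zeros((n_cells, n_batches))
    if strategy == "ones":
        return np.ones((n_cells, n_batches))
    if strategy == "one_batch":
        if not 0 <= batch_index < n_batches:
            raise ValueError("batch_index out of range")
        mat = np.zeros((n_cells, n_batches))
        mat[:, batch_index] = 1.0  # pretend every cell came from this one batch
        return mat
    raise ValueError(f"unknown strategy: {strategy!r}")


# Example: 5 cells, model trained on 3 batches, all cells mapped to batch 1
print(make_shared_batch_matrix(5, 3, "one_batch", batch_index=1))
```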

Hope you find it useful!

Perfect, thanks so much for the rapid response.

I just tried it out after modifying it a little bit and it seems to work well! In my case, I was using it on all the brain datasets merged into one ExprDataSet object.

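import numpy as np
import pandas as pd
import anndata
import Cell_BLAST as cb

def reconstruct_exprs(model,
                      dataset,
                      batch_effect="study"):
    # Decoder inputs: expression of the model's genes plus per-cell library size
    data_dict = {"exprs": dataset[:, model.genes].exprs,
                 "library_size": np.array(dataset.exprs.sum(axis=1)).reshape((-1, 1))
    }
    # Same all-zero batch vector for every cell, so the decoded output is batch-corrected.
    # Assumes dataset contains the same batches the model was trained on (one column each).
    data_dict[batch_effect] = np.zeros((
        dataset.shape[0],
        np.unique(dataset.obs[batch_effect]).size
    ))
    # softmax_mu: decoded per-cell expression proportions (each row sums to 1)
    corrected = model._fetch(model.prob_module.softmax_mu,
                             cb.utils.DataDict(data_dict))
    # var must be a DataFrame so that the gene names become var_names
    adat = anndata.AnnData(X=corrected,
                           obs=dataset.obs,
                           var=pd.DataFrame(index=model.genes))
    new_dat = cb.data.ExprDataSet.from_anndata(adat)
    return new_dat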

reconstructed = reconstruct_exprs(model=model,
                                  dataset=combined_dataset, 
                                  batch_effect="study")

Thanks again,
Brian

That's great! One more thing to note: the decoded values from "softmax_mu" sum to 1 for each cell, so their distribution differs from that of raw counts. This may violate the assumptions of downstream analyses that expect count data. Otherwise everything should be fine.
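If a downstream tool needs values on a count-like scale, one common convention (not something Cell BLAST does for you) is to multiply each cell's proportions by its library size, for example the per-cell totals the script above passes as "library_size". The sketch below assumes a dense NumPy array of proportions. The result is a scaled expectation, not a sampled count matrix, so count-based models may still fit it poorly even after rounding. `rescale_to_library_size` is a hypothetical helper.

```python
import numpy as np


def rescale_to_library_size(proportions, library_size, round_counts=False):
    """Scale per-cell proportions (rows summing to 1) to a count-like scale.

    Hypothetical helper: multiplies row i by library_size[i] and optionally
    rounds to the nearest integer.
    """
    proportions = np.asarray(proportions, dtype=float)
    # Column vector so that broadcasting scales each row (cell) by its own total
    library_size = np.asarray(library_size, dtype=float).reshape(-1, 1)
    scaled = proportions * library_size
    return np.rint(scaled) if round_counts else scaled


props = np.array([[0.5, 0.25, 0.25],
                  [0.1, 0.0, 0.9]])
lib = np.array([100, 20])
print(rescale_to_library_size(props, lib))
```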

Let me know if you encounter any further issues.

Thanks for the tip, I'll keep that in mind!