gdewael/cpg-transformer

Cell embeddings


Hi @gdewael,

Could you please tell me how the cell embeddings are calculated?

Thanks!

Hi! Cell embeddings are randomly initialized from a Gaussian N(0, 1) via PyTorch's nn.Embedding. Each cell in a dataset is assigned its own 64-dimensional vector, initialized at random. Functionally, this is equivalent to one-hot encoding every cell and then projecting that encoding into a 64-dimensional space with a linear layer (without bias). The embeddings themselves are parameters that are optimized directly during training. This way, the model learns to distinguish between cells and gives each cell its own identity.
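Here is a minimal PyTorch sketch of the equivalence described above. It assumes a made-up dataset of 20 cells and uses the 64-dimensional size from the answer. The variable names are illustrative and are not taken from the cpg-transformer code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_cells = 20     # hypothetical number of cells in a dataset
embed_dim = 64   # embedding size mentioned in the answer

torch.manual_seed(0)
# nn.Embedding initializes its (n_cells x embed_dim) weight matrix from N(0, 1) by default
cell_embedding = nn.Embedding(n_cells, embed_dim)

cell_ids = torch.tensor([0, 3, 7])     # integer index of each cell
via_lookup = cell_embedding(cell_ids)  # row lookup, shape (3, 64)

# Equivalent route: one-hot encode each cell, then apply a bias-free linear layer
linear = nn.Linear(n_cells, embed_dim, bias=False)
with torch.no_grad():
    # nn.Linear stores its weight as (out_features, in_features), hence the transpose
    linear.weight.copy_(cell_embedding.weight.T)
one_hot = F.one_hot(cell_ids, num_classes=n_cells).float()
via_linear = linear(one_hot)

print(via_lookup.shape)
print(torch.allclose(via_lookup, via_linear))
```

The lookup is just the cheaper way to compute the same product. It avoids building a mostly-zero one-hot matrix, which matters when a dataset contains many cells.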

When applying pre-trained models to new datasets, the cell embeddings from the previous dataset are discarded and new cell embeddings are trained (i.e. cell embeddings cannot be pre-trained, as we assume you will apply models to new cells).
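The following is a minimal sketch of swapping in fresh cell embeddings for a new dataset. It uses a hypothetical ToyCellModel with a cell_embedding attribute. These names are illustrative and are not the repository's API. The sketch assumes the new dataset has 35 cells and fine-tunes all parameters with plain SGD. The answer above does not say whether the other layers are frozen during this step.

```python
import torch
import torch.nn as nn

class ToyCellModel(nn.Module):
    """Hypothetical stand-in for a model with per-cell embeddings (not the repo's class)."""

    def __init__(self, n_cells: int, embed_dim: int = 64):
        super().__init__()
        self.cell_embedding = nn.Embedding(n_cells, embed_dim)
        self.head = nn.Linear(embed_dim, 1)  # stands in for the rest of the network

    def forward(self, cell_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.cell_embedding(cell_ids)).squeeze(-1)

torch.manual_seed(0)
model = ToyCellModel(n_cells=20)  # pretend this was pre-trained on a 20-cell dataset
head_before = model.head.weight.detach().clone()

# New dataset with 35 cells: drop the old embeddings and create fresh N(0, 1) ones
model.cell_embedding = nn.Embedding(35, 64)
head_kept = torch.equal(model.head.weight, head_before)  # other weights are untouched
new_before = model.cell_embedding.weight.detach().clone()

# One training step on the new dataset; the optimizer is built after the swap
# so that it tracks the new embedding table
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
cell_ids = torch.tensor([0, 10, 34])
targets = torch.tensor([1.0, 0.0, 1.0])
loss = nn.functional.mse_loss(model(cell_ids), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

after = model.cell_embedding.weight.detach().clone()
# Only the rows of cells seen in this batch receive a gradient and move
print(torch.equal(new_before[1], after[1]))
```

The optimizer has to be created (or have its parameter groups updated) after the swap. Otherwise it would still hold references to the discarded embedding table.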