scDGN is a supervised model for dimensionality reduction of scRNA-seq data. It accounts for batch effects by using conditional adversarial domain training.
The code depends on the following packages:
- PyTorch
- NumPy
- SciPy
- tqdm
Architecture of scDGN. The network includes three modules: an scRNA encoder (blue), a label classifier (orange), and a domain discriminator (red). Note that the red and orange networks take the same encoding as input. Solid lines represent the forward direction of the neural network, while dashed lines represent the backpropagation direction together with the gradient passed along each path. Gradient Reversal Layers (GRL) have no effect in forward propagation, but flip the sign of the gradients that flow through them during backpropagation. This allows the combined network to simultaneously optimize label classification and attempt to "fool" the domain discriminator. As a result, the encoder learns representations that are invariant to the different domains while still distinguishing cell types.
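The gradient reversal behaviour described above can be sketched in PyTorch as a custom autograd function. This is a minimal illustration of the general technique, not the code from this repository: the names GradReverse, grad_reverse and lam are hypothetical, and scaling the reversed gradient by a weight lam is an assumption about how a trade-off weight could enter.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass, sign-flipped gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, lam):
        # Store the (hypothetical) scaling weight for use in backward.
        ctx.lam = lam
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by lam.
        # The second return value is the gradient for lam, which needs none.
        return -ctx.lam * grad_output, None


def grad_reverse(x, lam=1.0):
    """Apply the gradient reversal layer to a tensor."""
    return GradReverse.apply(x, lam)


# Tiny check: the forward output is unchanged, the gradient is reversed.
x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, lam=0.1)
y.sum().backward()
```

In a setup like the one in the figure, the encoder output would pass through grad_reverse before entering the domain discriminator, while the label classifier would receive the encoding directly.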
First, create the environment with Anaconda. The commands below use PyTorch 1.0.1 and CUDA 9.0; for installing PyTorch with other versions of CUDA, see the PyTorch documentation:
git clone git@github.com:SongweiGe/scDGN.git
cd scDGN
mkdir scDGN/data scDGN/ckpts
conda create -n scDGN python=3.6
conda activate scDGN
conda install tqdm
conda install pytorch=1.0.1 cuda90 -c pytorch
Then download the contents of the scDGN/ folder from https://drive.google.com/drive/folders/1HUUT_g-uPHlLDnZS44CylH1i3DrawAsI?usp=sharing into the directory scDGN/data/.
Change the -o argument to whatever name you want. See run_scDGN.py to adjust the hyperparameters: l controls the trade-off between domain invariance and classification accuracy, dn is the name of the dataset, m is the margin in the contrastive loss, adv selects whether to train scDGN or the plain NN (in the commands below, -adv 1 is used for the scDGN run and -adv 0 for the NN run), and validation enables hyperparameter selection with a validation set.
python run_scDGN.py -dn pancreas1 -l 0.1 -m 1 -g 0 -adv 0 -o pancreas1_test_NN --validation 0
python run_scDGN.py -dn pancreas1 -l 0.1 -m 1 -g 0 -adv 1 -o pancreas1_test_CDGN --validation 0
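To make the flags in the commands above easier to read, here is a hypothetical argparse sketch that accepts the same options. It is not the parser from run_scDGN.py: the types, defaults and help texts are assumptions, and the meaning of -g (guessed here to be a GPU index) is not stated in this README.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the commands above."""
    p = argparse.ArgumentParser(description="Sketch of the scDGN training flags")
    p.add_argument("-dn", type=str, default="pancreas1", help="name of the dataset")
    p.add_argument("-l", type=float, default=0.1,
                   help="trade-off between domain invariance and classification accuracy")
    p.add_argument("-m", type=float, default=1.0, help="margin in the contrastive loss")
    p.add_argument("-g", type=int, default=0, help="assumed: index of the GPU to use")
    p.add_argument("-adv", type=int, default=1, help="1 for scDGN, 0 for the plain NN")
    p.add_argument("-o", type=str, default="output", help="name chosen for the run")
    p.add_argument("--validation", type=int, default=0,
                   help="1 to select hyperparameters with a validation set")
    return p


# Parse the same arguments as the second command above.
command = "-dn pancreas1 -l 0.1 -m 1 -g 0 -adv 1 -o pancreas1_test_CDGN --validation 0"
args = build_parser().parse_args(command.split())
```

Parsing the second command this way gives args.adv equal to 1 and args.o equal to pancreas1_test_CDGN, which matches the scDGN run.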
To infer and visualize the representations of the NN and scDGN models, refer to the Jupyter notebook PCA_visualization.ipynb. Specifically, use plot_pca_all to plot the full dataset.
Note that it is hard to visualize the whole scquery dataset in one image because of the large number of cell types and batches. You can plot selected batches and cell types with plot_pca_ct or plot_pca.
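The idea of projecting learned representations to two dimensions and then restricting the plot to certain cell types and batches can be sketched with NumPy alone. This is an illustrative sketch, not the implementation of plot_pca_ct or plot_pca in the notebook: the helper names pca_2d and select_cells, the array layout, and the random stand-in data are all assumptions.

```python
import numpy as np


def pca_2d(embeddings):
    """Project row-wise embeddings onto their first two principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


def select_cells(coords, cell_types, batches, keep_types=None, keep_batches=None):
    """Keep only the cells whose type and batch are in the requested sets."""
    mask = np.ones(len(coords), dtype=bool)
    if keep_types is not None:
        mask &= np.isin(cell_types, keep_types)
    if keep_batches is not None:
        mask &= np.isin(batches, keep_batches)
    return coords[mask], cell_types[mask], batches[mask]


# Random stand-in data: 100 cells with 16-dimensional representations.
rng = np.random.RandomState(0)
emb = rng.randn(100, 16)
cell_types = np.array(["alpha", "beta"] * 50)
batches = np.array([0] * 50 + [1] * 50)

coords = pca_2d(emb)
sub_coords, sub_types, sub_batches = select_cells(
    coords, cell_types, batches, keep_types=["alpha"], keep_batches=[0]
)
```

The selected coordinates could then be passed to any scatter-plot routine, colored by cell type or by batch.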
To select the important genes identified by the models, refer to the Jupyter notebook feature_importance.ipynb.