This repo provides a PyTorch implementation of the Dr. Frankenstein model stitching framework used to perform the experiments described in the paper *Similarity and Matching of Neural Network Representations*.
- We stitch trained neural networks together to study the similarity of their latent representations. (This technique was first used by [Lenc and Vedaldi, 2015].) Our starting point is the observation that stitching can be done with minor performance loss when the two networks share the same training datasets and network topologies; a minimal code sketch of such a stitched model follows this list.
- We study the relation between this functional similarity and previously studied geometric notions of representational similarity such as CKA and CCA. We observe that their behavior can be counter-intuitive from the perspective of functional similarity.
- We experiment with restricting the space of transformations in the stitching layer (to e.g., orthogonal, low rank, sparse) to investigate how information is organised in the latent representations.
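To make this concrete, here is a minimal sketch of a stitched model, assuming convolutional activations of shape [N, C, H, W] and hypothetical `front`/`tail` submodules. This is an illustration of the idea, not the repo's exact implementation:

```python
import torch
import torch.nn as nn

class StitchedNet(nn.Module):
    """Front of net A -> trainable 1x1 conv stitching layer -> tail of net B."""
    def __init__(self, front: nn.Module, tail: nn.Module, c_in: int, c_out: int):
        super().__init__()
        self.front, self.tail = front, tail
        # The stitching layer: an affine map between the two latent spaces.
        self.stitch = nn.Conv2d(c_in, c_out, kernel_size=1, bias=True)
        # Freeze both networks; only the stitching layer is trained.
        for p in self.front.parameters():
            p.requires_grad = False
        for p in self.tail.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.tail(self.stitch(self.front(x)))
```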
conda env create -f environment.yml
pip install -r requirements.txt
Note: PyTorch is not included in the requirements; you need to install it manually. We used version 1.8 in our experiments.
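For example (assuming a CUDA 10.2 setup; check pytorch.org for the command matching your platform):

```
pip install torch==1.8.1 torchvision==0.9.1
```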
If you prefer Docker, you can also pull a prebuilt container:
cd container
./pull.sh
cd ..
Afterwards, you can run any command inside the container, e.g.:
cd container
./docker_run.sh -g 0 -c "python python_file.py"
- -g Which GPU(s) to use for training; list the GPU IDs here
- -c The command to run inside the container
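For example, to launch the training run shown later in this README on GPU 0:

```
cd container
./docker_run.sh -g 0 -c "python train_model.py -m tiny10 -d cifar10"
```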
The config file needs to be set properly. As you can see in docker_run.sh, the /home/${USER}/cache folder is mapped to the /cache folder inside the container, so it is recommended to store your data in the /home/${USER}/cache/data/pytorch and /home/${USER}/cache/data/celeba folders and leave the config file with the default settings.
You must create the data folder outside the container first, e.g.:
mkdir /home/${USER}/cache
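For example, to create the full recommended layout in one step:

```
mkdir -p /home/${USER}/cache/data/pytorch /home/${USER}/cache/data/celeba
```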
There's a config file which tells the scripts where they can find or download the datasets to. Please edit config/default.env:
[dataset_root]
pytorch = '/cache/data/pytorch' # path to PyTorch datasets such as CIFAR-10
celeba = '/cache/data/celeba' # path to the CelebA dataset
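If you run the scripts outside the container, point these entries at your local folders instead, e.g. (hypothetical local paths):

```
[dataset_root]
pytorch = '/home/myuser/data/pytorch'
celeba = '/home/myuser/data/celeba'
```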
You can train your own networks as shown below. Some pretrained models are provided in the model_weights/ folder.
Example:
python train_model.py -m tiny10 -d cifar10
python stitch_nets.py path/to/model1.pt path/to/model2.pt layer1 layer2 -d dataset
An example stitch with pretrained models:
python stitch_nets.py model_weights/Tiny10/CIFAR10/in0-gn0/110000.pt model_weights/Tiny10/CIFAR10/in10-gn10/110000.pt bn3 bn3 -d cifar10 -i ps_inv
- -h, --help Get help about parameters
- --run-name The name of the subfolder to save to. If not given, it defaults to the current date-time.
- -e, --epochs Number of epochs to train. Default: 30
- -lr, --lr Learning rate. Default: 1e-3
- -o, --out-dir Folder to save networks to. Default: snapshots/
- -b, --batch-size Batch size. Default: 128
- -s, --save-frequency How often (in iterations) to save the transformation matrix. This number is multiplied by the number of epochs. Default: 10
- --seed Seed of the run. Default: 0
- -wd, --weight-decay Weight decay to use. Default: 1e-4
- --optimizer Name of the optimizer. Please choose from: adam, sgd. Default: adam
- --debug Whether to run in debug mode. Default: False
- --flatten Whether to flatten the layers around the transformation. NOTE: not used in the paper, hardly ever used, and it might be buggy. Default: False
- --l1 L1 regularisation applied to the transformation matrix. Default: 0
- --cka-reg CKA regularisation applied to the transformation matrix. Default: 0
- -r, --low-rank Maximum rank of the transformation matrix. Uses the maximum rank by default.
- -i, --init Initialisation of transformation matrix. Options:
- random: random initialisation. Default.
- perm: random permutation
- eye: identity matrix
- ps_inv: pseudo-inverse initialisation (see the sketch after this list)
- ones-zeros: the weight matrix is all ones, the bias all zeros.
- -m, --mask Any mask applied to the transformation. Options:
- identity: All values in the mask are 1. Default.
- semi-match: Choose the best pairs based on correlation.
- abs-semi-match: Semi-match between absolute correlations.
- random-permutation: A random permutation matrix.
- --target-type The loss to apply at the logits. Options:
- hard: Use the true labels. Default.
- soft_1: Use soft crossentropy loss against model1's outputs.
- soft_2: Use soft crossentropy loss against model2's outputs.
- soft_12: Use soft crossentropy loss against the mean of model1's and model2's outputs.
- soft_1_plus_2: Use soft crossentropy loss against the sum of model1's and model2's outputs.
- --temperature The temperature to use if the target type is a soft label. Default: 1.
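For reference, the pseudo-inverse (ps_inv) initialisation mentioned above amounts to a least-squares fit between paired activations of the two models. A minimal sketch, assuming flattened activation matrices and PyTorch >= 1.9 for torch.linalg.lstsq (not necessarily the repo's exact implementation):

```python
import torch

def ps_inv_init(acts1: torch.Tensor, acts2: torch.Tensor):
    # acts1: [N, C1] activations of model1; acts2: [N, C2] activations of model2.
    # Find W, b minimising ||acts1 @ W.T + b - acts2||^2 by least squares.
    ones = torch.ones(acts1.shape[0], 1)
    A = torch.cat([acts1, ones], dim=1)           # append a bias column
    sol = torch.linalg.lstsq(A, acts2).solution   # shape [(C1+1), C2]
    W, b = sol[:-1].T, sol[-1]                    # W: [C2, C1], b: [C2]
    return W, b
```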
By default you will find the results of your runs under the results/ folder: a pickle file that contains all information about your run, e.g. the weights & bias of the stitching layer, accuracy, crossentropy, etc.
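The exact contents depend on the run, but the pickle can be inspected along these lines (hypothetical file name; the structure is whatever the script saved):

```python
import pickle

with open('results/stitching_result.pkl', 'rb') as f:
    run = pickle.load(f)
print(type(run))  # inspect the stored object to see the available fields
```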
Print the layer information of an architecture; one can stitch between the printed layers:
python layer_info.py model_name
Supported model names: lenet, tiny10, nbn_tiny10, nbntiny10, dense, inceptionv1, resnet20_w*
Example:
python layer_info.py resnet20_w3
python compare_frank_m2.py path/to/file.pkl stitch_type measure1 measure2 measure3 ...
Stitching types:
- before - initial state of transformation matrix before training
- after - trained transformation matrix
- ps_inv - use the pseudo-inverse transformation (calculated on the validation set)
Example:
python compare_frank_m2.py results/stitching_result.pkl after cka
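For reference, linear CKA, one of the supported measures, can be computed as follows. This is the standard formulation from the literature, not necessarily this repo's exact implementation:

```python
import torch

def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float:
    # X: [N, D1], Y: [N, D2] activation matrices over the same N examples.
    X = X - X.mean(dim=0, keepdim=True)   # centre the features
    Y = Y - Y.mean(dim=0, keepdim=True)
    # linear CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = torch.norm(Y.T @ X) ** 2
    return (cross / (torch.norm(X.T @ X) * torch.norm(Y.T @ Y))).item()
```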
python compare_nets.py path/to/model1.pt path/to/model2.pt layer1 layer2 dataset method1 method2 method3 ...
Example:
python compare_nets.py model_weights/Tiny10/CIFAR10/in0-gn0/110000.pt model_weights/Tiny10/CIFAR10/in10-gn10/110000.pt bn5 bn5 cifar10 cka l2
To evaluate a trained network:
python eval_net.py path/to/model.pt
Example:
python eval_net.py model_weights/Tiny10/CIFAR10/in0-gn0/110000.pt
To evaluate a stitched network:
python eval_stitch.py results/stitching_result.pkl stitch_type
Stitching types:
- before - initial state of transformation matrix before training
- after - trained transformation matrix
- ps_inv - use the pseudo-inverse transformation (calculated on the validation set)
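Example (assuming a stitching run saved as results/stitching_result.pkl):

```
python eval_stitch.py results/stitching_result.pkl after
```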