
PyTorch implementation of the tensor decomposition methods for convolutional layers from arXiv [1412.6553] (CP decomposition) and [1511.06530] (Tucker decomposition).


PyTorch Tensor Decompositions

This is an implementation of Tucker and CP decompositions of convolutional layers. A blog post about this can be found here.
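Counting weights shows why these decompositions shrink a layer. The two cited papers each replace one d x d convolution from S input channels to T output channels with a chain of smaller convolutions:

- Tucker over the channel modes: a 1x1 conv S → R_in, then a d x d core conv R_in → R_out, then a 1x1 conv R_out → T.
- Rank-R CP: a 1x1 conv S → R, then d x 1 and 1 x d depthwise convs, then a 1x1 conv R → T.

The minimal sketch below only does this arithmetic. Biases are ignored. The 512-channel layer and the ranks 128 and 256 are hypothetical values chosen for illustration, not ranks produced by this repository.

```python
def conv_params(s: int, t: int, d: int) -> int:
    """Weights of one d x d convolution from s input to t output channels (bias ignored)."""
    return t * s * d * d


def tucker2_params(s: int, t: int, d: int, r_in: int, r_out: int) -> int:
    """Tucker chain: 1x1 conv (s -> r_in), d x d core conv (r_in -> r_out), 1x1 conv (r_out -> t)."""
    return s * r_in + r_in * r_out * d * d + r_out * t


def cp_params(s: int, t: int, d: int, r: int) -> int:
    """CP chain: 1x1 conv (s -> r), d x 1 depthwise conv, 1 x d depthwise conv, 1x1 conv (r -> t)."""
    # Each depthwise conv holds one d-length filter per rank-1 component.
    return s * r + r * d + r * d + r * t


if __name__ == "__main__":
    s, t, d = 512, 512, 3  # hypothetical 3x3 layer with 512 input/output channels
    print(conv_params(s, t, d))                # 2359296
    print(tucker2_params(s, t, d, 128, 128))   # 278528
    print(cp_params(s, t, d, 256))             # 263680
```

With these hypothetical ranks, both chains keep about 11 to 12 percent of the original weights. Lower ranks shrink the layer further but make the approximation of the original kernel coarser.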

It depends on TensorLy for performing tensor decompositions.
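The repository delegates the decomposition itself to TensorLy. To illustrate without that dependency what a Tucker decomposition over the two channel modes produces, the sketch below swaps in a plain truncated HOSVD written with NumPy. It is neither TensorLy's algorithm nor this repository's code. It assumes the PyTorch weight layout (T, S, d, d), with output channels first. The names `unfold`, `tucker2_hosvd` and `reconstruct` are hypothetical.

```python
import numpy as np


def unfold(tensor: np.ndarray, mode: int) -> np.ndarray:
    """Mode-n unfolding: move `mode` to the front and flatten the remaining axes."""
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)


def tucker2_hosvd(weight: np.ndarray, rank_out: int, rank_in: int):
    """Truncated HOSVD on the output- and input-channel modes of a (T, S, d, d) kernel.

    Returns core (rank_out, rank_in, d, d), u_out (T, rank_out), u_in (S, rank_in).
    """
    # Leading left singular vectors of each unfolding give orthonormal factor matrices.
    u_out = np.linalg.svd(unfold(weight, 0), full_matrices=False)[0][:, :rank_out]
    u_in = np.linalg.svd(unfold(weight, 1), full_matrices=False)[0][:, :rank_in]
    # Project the kernel onto both factor subspaces to obtain the smaller core.
    core = np.einsum("tsij,tr,sq->rqij", weight, u_out, u_in)
    return core, u_out, u_in


def reconstruct(core: np.ndarray, u_out: np.ndarray, u_in: np.ndarray) -> np.ndarray:
    """Map the core back to full channel dimensions (approximation of the original kernel)."""
    return np.einsum("rqij,tr,sq->tsij", core, u_out, u_in)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weight = rng.standard_normal((64, 32, 3, 3))  # hypothetical conv kernel
    core, u_out, u_in = tucker2_hosvd(weight, rank_out=16, rank_in=8)
    print(core.shape, u_out.shape, u_in.shape)  # (16, 8, 3, 3) (64, 16) (32, 8)
    approx = reconstruct(core, u_out, u_in)
    print(np.linalg.norm(weight - approx) / np.linalg.norm(weight))  # relative error
```

The three factors map onto the Tucker layer chain:

- `u_in.T`, reshaped to (rank_in, S, 1, 1), is the weight of the first 1x1 convolution.
- `core` is the weight of the d x d convolution.
- `u_out`, reshaped to (T, rank_out, 1, 1), is the weight of the last 1x1 convolution.

Composing the three convolutions multiplies out to exactly the kernel returned by `reconstruct`.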

Usage

  • Train a model by fine-tuning VGG16: python main.py --train

  • The dataset should have two categories, with one directory per category. Training data goes into a directory called 'train' and testing data into a directory called 'test'. These locations can be changed with the --train_path and --test_path flags.

  • I used the Kaggle Cats/Dogs dataset.

  • The model is then saved into a file called "model".

  • Perform a decomposition with python main.py --decompose. This saves the new model to a file called "decomposed_model". Tucker decomposition is used by default. To use CP decomposition instead, also pass --cp.

  • Fine tune the decomposed model: python main.py --fine_tune
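The dataset layout from the usage steps above can be sanity-checked before training. The helper below is a hypothetical sketch and not part of main.py. It assumes the default 'train' and 'test' directories, or whatever paths are passed via --train_path and --test_path. It requires exactly two category subdirectories in each split. It also assumes both splits use the same category names (for example cats and dogs for the Kaggle Cats/Dogs data), which the instructions imply but do not state.

```python
from pathlib import Path


def check_dataset_layout(train_path: str = "train", test_path: str = "test") -> list:
    """Hypothetical check: each split holds exactly two category directories with matching names.

    Returns the sorted category names, or raises ValueError describing the problem.
    """
    categories = None
    for split in (Path(train_path), Path(test_path)):
        if not split.is_dir():
            raise ValueError(f"missing directory: {split}")
        # One subdirectory per category; loose files are ignored.
        names = sorted(p.name for p in split.iterdir() if p.is_dir())
        if len(names) != 2:
            raise ValueError(f"{split} should contain 2 category directories, found {names}")
        # Assumption: train and test use the same category names.
        if categories is not None and names != categories:
            raise ValueError(f"category names differ between splits: {categories} vs {names}")
        categories = names
    return categories


if __name__ == "__main__":
    print(check_dataset_layout())  # e.g. ['cats', 'dogs'] when run next to train/ and test/
```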

References