pytorch-revnet

revnet

PyTorch implementation of the reversible residual network.
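For readers new to the architecture, the standard reversible coupling can be sketched as follows. The activations are split into two halves (x1, x2), and each block computes y1 = x1 + F(x2) and y2 = x2 + G(y1). Because the inputs can be recomputed exactly from the outputs (x2 = y2 - G(y1), then x1 = y1 - F(x2)), activations do not have to be stored for the backward pass. This is a framework-free sketch, not the code used in this repository. The functions `f` and `g` are hypothetical stand-ins for the residual sub-networks, and the helper names are illustrative.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def reversible_forward(x1: Vector, x2: Vector,
                       f: Callable[[Vector], Vector],
                       g: Callable[[Vector], Vector]) -> Tuple[Vector, Vector]:
    """Reversible coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2


def reversible_inverse(y1: Vector, y2: Vector,
                       f: Callable[[Vector], Vector],
                       g: Callable[[Vector], Vector]) -> Tuple[Vector, Vector]:
    """Recover the inputs from the outputs: x2 = y2 - G(y1), x1 = y1 - F(x2)."""
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2


# Hypothetical stand-ins for the residual functions F and G
# (in a real RevNet these would be stacks of conv / batch-norm / ReLU layers).
def f(v: Vector) -> Vector:
    return [math.tanh(2.0 * t) for t in v]


def g(v: Vector) -> Vector:
    return [max(0.0, t) * 0.5 for t in v]


x1, x2 = [0.3, -1.2, 2.0], [1.5, 0.7, -0.4]
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_inverse(y1, y2, f, g)
# The reconstruction matches the original inputs up to floating-point rounding.
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(x1 + x2, r1 + r2)))  # True
```

Note that F and G never need to be inverted themselves. The additive structure of the coupling alone makes the block invertible, which is why arbitrary non-invertible residual functions can be used.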

Requirements

The main requirement is PyTorch. A CUDA-capable GPU is strongly recommended.

The training script requires tqdm for the progress bar.

The unit tests require the TestCase class implemented by the PyTorch project. The module can be downloaded here.

Note

The RevNet models in this project tend to suffer from exploding gradients. To counteract this, I used gradient-norm clipping. To reproduce the experiments below, run the following command:

python train_cifar.py --model revnet38 --clip 0.25
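Gradient-norm clipping rescales all gradients by one common factor whenever their combined L2 norm exceeds a threshold. This keeps the update direction but caps its magnitude. In PyTorch this is commonly done with `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`, called between `loss.backward()` and `optimizer.step()`. The framework-free sketch below shows the computation on plain lists of gradient values. It assumes that `--clip 0.25` is used as the maximum norm, which is an assumption about `train_cifar.py` rather than something stated here. The function name `clip_grad_norm` is hypothetical.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float,
                   eps: float = 1e-6) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the total norm measured before clipping.
    """
    # Global L2 norm over every entry of every parameter's gradient.
    total_norm = math.sqrt(sum(v * v for grad in grads for v in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # One shared factor for all parameters, so the overall direction is preserved.
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm = 5.0
norm_before = clip_grad_norm(grads, max_norm=0.25)
norm_after = math.sqrt(sum(v * v for grad in grads for v in grad))
print(norm_before, round(norm_after, 4))  # 5.0 0.25
```

Clipping the global norm, rather than each value individually, avoids distorting the relative sizes of the gradients across layers.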

Results

CIFAR-10

| Model    | Accuracy | Memory usage | Parameters |
|----------|----------|--------------|------------|
| resnet32 | 92.02%   | 1271 MB      | 0.47 M     |
| revnet38 | 91.98%   | 660 MB       | 0.47 M     |
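The trade-off in the CIFAR-10 table can be made concrete with a little arithmetic. The numbers below are copied from the table above, and the variable names are illustrative.

```python
# Figures copied from the CIFAR-10 results table.
resnet32 = {"accuracy": 92.02, "memory_mb": 1271}
revnet38 = {"accuracy": 91.98, "memory_mb": 660}

memory_saving = 1 - revnet38["memory_mb"] / resnet32["memory_mb"]
accuracy_gap = resnet32["accuracy"] - revnet38["accuracy"]

print(f"memory saving: {memory_saving:.1%}")       # memory saving: 48.1%
print(f"accuracy gap: {accuracy_gap:.2f} points")  # accuracy gap: 0.04 points
```

So revnet38 uses roughly half the memory of resnet32, with the same parameter count and an accuracy difference of 0.04 percentage points.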