This repo contains scripts to train and evaluate quantized versions of popular deep learning classification models (LeNet, AlexNet and VGG) using Brevitas.

Requirements:
- PyTorch >= 1.1.0
- Brevitas (https://github.com/Xilinx/brevitas)
After installing PyTorch, install Brevitas:
```bash
git clone https://github.com/Xilinx/brevitas
cd brevitas
pip install .
```
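Before training, it can help to confirm that the environment meets the requirements. The Python sketch below (Python 3.8+, for `importlib.metadata`) reads the installed versions of the `torch` and `brevitas` pip distributions and checks PyTorch against the 1.1.0 minimum. It is a hypothetical convenience script, not part of this repo.

```python
import re
from importlib import metadata


def parse_version(version):
    """Turn a version string such as '1.13.1+cu117' into (1, 13, 1), ignoring suffixes."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def version_at_least(installed, minimum):
    """Compare versions numerically, padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    length = max(len(a), len(b))
    a += (0,) * (length - len(a))
    b += (0,) * (length - len(b))
    return a >= b


def check_requirements():
    """Return {name: (installed_version_or_None, requirement_met)}."""
    report = {}
    try:
        torch_version = metadata.version("torch")
        report["torch"] = (torch_version, version_at_least(torch_version, "1.1.0"))
    except metadata.PackageNotFoundError:
        report["torch"] = (None, False)
    try:
        report["brevitas"] = (metadata.version("brevitas"), True)
    except metadata.PackageNotFoundError:
        report["brevitas"] = (None, False)
    return report


if __name__ == "__main__":
    for name, (version, ok) in check_requirements().items():
        print(f"{name}: {version or 'not installed'} -> {'OK' if ok else 'missing or too old'}")
```

Comparing numeric tuples rather than raw strings avoids the classic mistake of treating "1.10.0" as older than "1.1.0".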
Then clone this repo:
```bash
git clone https://github.com/MinahilRaza/Brevitas_Fixed_Point.git
```
This repo includes pretrained models with the following top-1 accuracies:
Model | Dataset | Weight quantization | Activation quantization | Top-1 accuracy (Brevitas) |
---|---|---|---|---|
LeNet | MNIST | 8 bit | 8 bit | 98.99% |
LeNet | MNIST | 8 bit (conv layers only) | 8 bit (conv layers only) | 99.08% |
AlexNet | CIFAR10 | 8 bit | 8 bit | 85.42% |
VGG | CIFAR10 | 8 bit | 8 bit | 86.22% |
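As a rough illustration of what 8-bit fixed-point quantization does to a tensor, the Python sketch below maps floats to signed 8-bit integers using a power-of-two scale (`frac_bits` fractional bits), rounding half to even and clipping to [-128, 127]. It is a generic, hypothetical helper and not the Brevitas API: Brevitas offers several scaling and rounding options, and the configuration these models actually use is defined in the training scripts.

```python
import math


def quantize_fixed_point(values, bit_width=8, frac_bits=None):
    """Quantize floats to signed fixed-point integers (hypothetical helper).

    Each value is multiplied by 2**frac_bits, rounded (Python's round-half-to-even)
    and clipped to the signed range of `bit_width` bits. If frac_bits is None,
    it is set to the largest value that keeps max(|v|) inside that range.
    """
    qmin = -(2 ** (bit_width - 1))
    qmax = 2 ** (bit_width - 1) - 1
    if frac_bits is None:
        max_abs = max(abs(v) for v in values)
        # All-zero input: any scale is exact, so use the finest one.
        frac_bits = bit_width - 1 if max_abs == 0 else math.floor(math.log2(qmax / max_abs))
    scale = 2.0 ** frac_bits
    ints = [min(max(round(v * scale), qmin), qmax) for v in values]
    return ints, frac_bits


def dequantize(ints, frac_bits):
    """Map fixed-point integers back to floats."""
    return [q / 2.0 ** frac_bits for q in ints]


weights = [0.5, -0.75, 0.3, 1.0]
q, frac_bits = quantize_fixed_point(weights)
print(frac_bits, q, dequantize(q, frac_bits))  # → 6 [32, -48, 19, 64] [0.5, -0.75, 0.296875, 1.0]
```

With these inputs the scale is 1/64, so 0.3 comes back as 0.296875; that gap is the precision lost to 8-bit quantization, which is why quantized accuracy can differ slightly from a floating-point baseline.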
To launch training, run one of the following commands from within the `training_scripts` folder.

LeNet on MNIST:
```bash
python run.py --network Lenet --dataset MNIST
```

AlexNet on CIFAR10:
```bash
python run.py --network AlexNet --dataset CIFAR10
```

VGG on CIFAR10:
```bash
python run.py --network VGG --dataset CIFAR10
```
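To train all three models in one go, a small Python wrapper can run the commands above in sequence. This is a hypothetical convenience sketch, not a script shipped with the repo; it assumes it is launched from the repository root (so `training_scripts/` is a valid relative path) and uses the current interpreter (`sys.executable`) in place of `python`.

```python
import subprocess
import sys
from pathlib import Path

# (network, dataset) pairs exactly as passed to run.py in the commands above.
RUNS = [("Lenet", "MNIST"), ("AlexNet", "CIFAR10"), ("VGG", "CIFAR10")]


def build_command(network, dataset):
    """Return the argument list for one training run."""
    return [sys.executable, "run.py", "--network", network, "--dataset", dataset]


def train_all(scripts_dir="training_scripts", dry_run=False):
    """Run each training job in order from scripts_dir; raise on the first failure."""
    commands = [build_command(network, dataset) for network, dataset in RUNS]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True raises CalledProcessError if run.py exits with an error.
            subprocess.run(cmd, cwd=Path(scripts_dir), check=True)
    return commands


if __name__ == "__main__":
    train_all(dry_run="--dry-run" in sys.argv)
```

Passing `--dry-run` to this wrapper only prints the commands, which is a quick way to confirm them before starting long training jobs.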
To resume training from a saved checkpoint, use the `--resume` flag and set its value to true. To evaluate on the validation set, set the `--evaluate` flag to true.
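The README does not show how `run.py` turns the value `true` into a boolean. If it uses `argparse`, a common pattern is a small converter passed as `type=`; the Python sketch below shows that pattern with a hypothetical `str2bool` helper and parser. Treat it as an illustration of the idea, not as the repo's actual parser, and check the options defined in `run.py` for the exact syntax.

```python
import argparse


def str2bool(value):
    """Convert strings such as 'true'/'false' into booleans (hypothetical helper)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Hypothetical parser mirroring the flags mentioned in this README."""
    parser = argparse.ArgumentParser(description="Sketch of run.py-style options")
    parser.add_argument("--network", default="Lenet")
    parser.add_argument("--dataset", default="MNIST")
    parser.add_argument("--resume", type=str2bool, default=False)
    parser.add_argument("--evaluate", type=str2bool, default=False)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--network", "VGG", "--dataset", "CIFAR10", "--resume", "true"]
    )
    print(args.resume, args.evaluate)  # → True False
```

A plain `type=bool` would not work here, because `bool("false")` is `True` for any non-empty string; an explicit converter avoids that pitfall.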
This README format was inspired by brevitas_cnv_lfc (https://github.com/ussamazahid96/brevitas_cnv_lfc)