CNN pruning with TensorFlow.

Primary language: Python. License: MIT.

SimplePruning

This repository provides a CNN channel-pruning demo with TensorFlow. You can prune your own model defined in modelsets.py; supported ops include conv2d, depthwise conv2d, pooling, fully connected, concat, add and more. Have a good time!

Author: Haibo. Contributions welcome.
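
The criterion SimplePruning uses to decide which channels to drop is not described in this README. As a minimal sketch, assuming an L1-norm ranking of filters (a common criterion swapped in for illustration, not necessarily this repo's method), the snippet below shows what channel pruning does to two consecutive conv layers: removing output channels of one layer forces removing the matching input channels of the next. The names select_channels_l1 and prune_conv_pair are hypothetical and not part of pruner.py. Kernels follow TensorFlow's (height, width, in_channels, out_channels) layout, and "pruning rate" is assumed to mean the fraction of channels removed.

```python
import numpy as np


def select_channels_l1(kernel, pruning_rate):
    """Return sorted indices of the output channels to keep.

    Hypothetical helper. Assumes a TensorFlow-style conv kernel of shape
    (kernel_h, kernel_w, in_channels, out_channels) and ranks output
    channels by the L1 norm of their filters (an assumed criterion).
    """
    out_channels = kernel.shape[-1]
    # Assumption: pruning_rate is the fraction of channels removed.
    n_keep = max(1, out_channels - int(out_channels * pruning_rate))
    # L1 norm of each output filter, summed over h, w and input channels.
    scores = np.abs(kernel).sum(axis=(0, 1, 2))
    # Indices of the n_keep largest scores, returned in ascending order.
    keep = np.argsort(scores)[-n_keep:]
    return np.sort(keep)


def prune_conv_pair(kernel, next_kernel, pruning_rate):
    """Prune output channels of `kernel` and the matching input channels
    of `next_kernel`. Biases and BatchNorm parameters of the first layer
    would be indexed with the same `keep` array."""
    keep = select_channels_l1(kernel, pruning_rate)
    return kernel[..., keep], next_kernel[:, :, keep, :], keep


rng = np.random.RandomState(0)
conv1 = rng.normal(size=(3, 3, 3, 8))    # 3 -> 8 channels
conv2 = rng.normal(size=(3, 3, 8, 16))   # 8 -> 16 channels
conv1_p, conv2_p, kept = prune_conv_pair(conv1, conv2, pruning_rate=0.5)
print(conv1_p.shape, conv2_p.shape)  # (3, 3, 3, 4) (3, 3, 4, 16)
```

Only the shapes change: the surviving filters keep their trained weights, which is why a pruned model is usually fine-tuned afterwards rather than trained from scratch.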


Dependencies

  • TensorFlow >= 1.10.0
  • Python >= 3.5
  • opencv-python >= 4.1.0
  • numpy >= 1.14.5
  • matplotlib >= 3.0.3

Getting Started

  • Clone the repository
  $ git clone https://github.com/DasudaRunner/SimplePruning.git
  • Download the CIFAR-10 dataset and put it into cifar-10-python/

    Url: http://www.cs.toronto.edu/~kriz/cifar.html

  • (Optional) Define your model in modelsets.py

    You must use the add_layer() API defined in pruner.py to build your model. See modelsets.py for more details.

  • (Optional) Configure parameters in utils/config.py

    e.g. model name, learning rate, pruning rate.

  • Train the full model. The .ckpt and .pb model files will be saved in ckpt_model/
  $ python full_train.py
  • Prune the channels. The .ckpt and .pb model files will be saved in channels_pruned_model/
  $ python channels_pruning.py
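
The CIFAR-10 "python version" ships as pickled batch files. How SimplePruning itself reads them is not shown here; the sketch below is a hypothetical loader (load_cifar10_batch and load_cifar10_train are illustrative names, not functions from this repo) that follows the format documented on the CIFAR-10 page. The root argument is assumed to be the directory holding data_batch_1 to data_batch_5; the official cifar-10-python.tar.gz archive extracts them into cifar-10-batches-py/, and the exact sub-path the repo expects under cifar-10-python/ is not confirmed.

```python
import os
import pickle

import numpy as np


def load_cifar10_batch(path):
    """Load one batch file of the CIFAR-10 python version (hypothetical helper).

    Returns images as uint8 NHWC (N, 32, 32, 3) and labels as int64 (N,).
    """
    with open(path, "rb") as f:
        # The batches were pickled under Python 2, so keys come back as bytes.
        batch = pickle.load(f, encoding="bytes")
    data = np.asarray(batch[b"data"], dtype=np.uint8)
    # Each row holds 3072 values: 1024 red, then 1024 green, then 1024 blue,
    # each 32x32 plane stored row-major. Reshape to NCHW, then move channels
    # last to match TensorFlow's default NHWC layout.
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels


def load_cifar10_train(root):
    """Concatenate data_batch_1 ... data_batch_5 found directly under `root`."""
    parts = [load_cifar10_batch(os.path.join(root, "data_batch_%d" % i))
             for i in range(1, 6)]
    images = np.concatenate([p[0] for p in parts])
    labels = np.concatenate([p[1] for p in parts])
    return images, labels
```

The test split lives in a single file named test_batch, which load_cifar10_batch can read directly.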

Supported ops (TensorFlow)

  • Conv2d
  • FullyConnected
  • MaxPooling, AveragePooling
  • BatchNormalization
  • Activation
  • DepthwiseConv2d
  • GlobalMaxPooling, GlobalAveragePooling
  • Concat
  • Add
  • Flatten
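
Concat and Add are the ops that make channel pruning non-local: their outputs mix channels from several branches, so the kept channel indices must be reconciled before the next layer is sliced. The README does not say how SimplePruning resolves this; the sketch below shows one plausible bookkeeping scheme under stated assumptions. The helper names are hypothetical, and taking the union of kept channels for Add is a conservative design choice assumed here (an intersection would prune more aggressively).

```python
import numpy as np


def merge_for_add(keep_a, keep_b):
    """Channel indices to keep at an element-wise Add (hypothetical helper).

    Both inputs of an Add must keep identical channel indices. This sketch
    keeps the union of what each branch wanted, so no branch loses a channel
    it ranked as important.
    """
    return np.union1d(keep_a, keep_b)


def merge_for_concat(keeps, channel_counts):
    """Kept channel indices in the output of a channel-axis Concat.

    `keeps[i]` indexes into branch i, which originally had
    `channel_counts[i]` channels. Each branch's indices are shifted by the
    total channel count of the branches before it.
    """
    offsets = np.cumsum([0] + list(channel_counts[:-1]))
    return np.concatenate([np.asarray(k) + off for k, off in zip(keeps, offsets)])


print(merge_for_add([0, 2, 5], [2, 3]))           # [0 2 3 5]
print(merge_for_concat([[0, 2], [1]], [4, 3]))    # [0 2 5]
```

The merged indices are what the layer after the Concat or Add would use to slice its input channels, in the same way the next conv kernel is sliced after an ordinary conv.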

Evaluation on the CIFAR-10 dataset

Model          Dataset    Pruning rate   Model size (MB)    Inference time (ms per 64 images)
SimpleNet      CIFAR-10   0.5            8.7 -> 1.8         5.8 -> 2.7
VGG19          CIFAR-10   0.5            53.4 -> 13.5       28.62 -> 9.44
DenseNet40     CIFAR-10   0.5            4.3 -> 1.5         77.87 -> 39.97
MobileNet V1   CIFAR-10   0.5            6.6 -> 1.8         19.39 -> 8.01
OCR-Net        ---        0.5            2426.2 -> 841.9    10.36 -> 7.3
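
The before/after figures in the table can be turned into compression and speed-up ratios with a few lines of arithmetic. The numbers below are copied from the table; nothing else is assumed. One useful reference point: if a conv layer keeps half of both its input and output channels, it keeps 0.5 x 0.5 = 0.25 of its weights, which is consistent with VGG19 shrinking to roughly a quarter of its size at a pruning rate of 0.5. Layers that keep their input width (such as the first conv) shrink less, so other models land below 4x.

```python
# (model, size before MB, size after MB, time before ms, time after ms),
# copied from the evaluation table; times are per batch of 64 images.
RESULTS = [
    ("SimpleNet", 8.7, 1.8, 5.8, 2.7),
    ("VGG19", 53.4, 13.5, 28.62, 9.44),
    ("DenseNet40", 4.3, 1.5, 77.87, 39.97),
    ("MobileNet V1", 6.6, 1.8, 19.39, 8.01),
    ("OCR-Net", 2426.2, 841.9, 10.36, 7.3),
]


def ratios(rows):
    """Map model name -> (size reduction factor, speed-up factor)."""
    return {name: (s_full / s_pruned, t_full / t_pruned)
            for name, s_full, s_pruned, t_full, t_pruned in rows}


# Halving both input and output channels of a 3x3 conv keeps a quarter of it.
c = 64
assert (3 * 3 * (c // 2) * (c // 2)) / (3 * 3 * c * c) == 0.25

for name, (size_x, speed_x) in ratios(RESULTS).items():
    print("{:<13} size x{:.2f}  speed-up x{:.2f}".format(name, size_x, speed_x))
```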

Update logs

  • 2019.07.24
    • Add support for Add op.
    • Add support for ResNet18/ResNet34 in modelsets.py.
  • 2019.07.16
    • Add support for Concat op.
    • Add support for DenseNet40 in modelsets.py.
  • 2019.07.14
    • Refactor SimplePruning.