Based on PyTorch 0.4.0.
For now, it is mainly used for pruning research.
However, we will try to make it more general so that it can handle other tasks such as object detection and semantic segmentation (there is still a long way to go).
Sorry, I have just restructured my codebase and have not had time to rewrite the README yet.
- main.py
- utils/
  - trainer.py
  - plugins.py
  - convert_DataParallel_Model.py
  - getClassesFromOfficialDataset.py
  - measure.py
- nets/
- prune.py
- prune_resnet.py
main.py is the entry point of the codebase. It uses a parser to parse the command-line arguments, then initializes the datasets, models, optimizers and criteria, and finally calls the trainer to train, validate and save checkpoints as well as model parameters. We will add more comments soon.
utils/ is a folder of helper modules, of which trainer.py is the most important, because most of the deep learning training process is implemented there. We will provide more details in the README soon.
trainer.py contains the core implementation of the training process: training, evaluation (validation), checkpoint saving (save), checkpoint loading (resume) and so on. These are wrapped into a Trainer class with attributes such as self.model, self.optimizer and self.dataloader (train_loader and validation_loader), and methods such as self.start (the only method you need to call from outside), self.train and self.validate. The main idea is to keep the training process as flexible and loosely coupled as possible, so the key component is self.plugins, which lets users customize the process. self.plugins consists of five lists: 'iteration', 'epoch', 'batch', 'update' and 'dataforward'.
- iteration: typically called after the model produces its output. For example, we usually use this kind of plugin to compute the accuracy and then log it.
- epoch: usually called after an epoch has finished. For example, we use it to write data to tensorboardX.
- batch: literally, a preprocessing step applied before the batch data is fed into the model.
- update: called after the loss has been backpropagated and before optimizer.step() performs the actual update, so it can adjust the model (for example its gradients) in between. For example, we use it to implement LASSO on the model.
- dataforward: as the name suggests, you define exactly one to decide how the data flows through the model and how the loss is calculated. For example, we use it to implement Knowledge Distillation, which requires two forward passes (one for the teacher model and one for the student model) as well as computing the MSE loss between them.
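To make the plugin mechanism concrete, here is a minimal, framework-free sketch of a trainer with the five plugin lists described above. It is an illustration under stated assumptions, not the repository's Trainer: a two-weight linear model trained with plain SGD stands in for self.model and self.optimizer, validation and checkpointing are omitted, and the names register_plugin, default_forward and lasso_update are hypothetical. The LASSO example adds the L1 subgradient strength * sign(w) to each gradient in an 'update' plugin, i.e. between the backward pass and the step.

```python
class Trainer:
    """Toy trainer: a linear model trained with SGD stands in for self.model/self.optimizer."""

    PLUGIN_TYPES = ("iteration", "epoch", "batch", "update", "dataforward")

    def __init__(self, weights, train_data, lr=0.05):
        self.weights = list(weights)      # stand-in for the model's parameters
        self.train_data = train_data      # list of (features, target) pairs
        self.lr = lr
        self.plugins = {kind: [] for kind in self.PLUGIN_TYPES}
        self.history = []                 # filled by epoch plugins

    def register_plugin(self, kind, fn):  # hypothetical helper name
        if kind not in self.plugins:
            raise ValueError(f"unknown plugin type: {kind}")
        self.plugins[kind].append(fn)

    def default_forward(self, x, y):
        # Forward pass, squared-error loss and its gradient w.r.t. each weight
        pred = sum(w * xi for w, xi in zip(self.weights, x))
        err = pred - y
        return pred, err * err, [2 * err * xi for xi in x]

    def train(self, epoch):
        # Only one dataforward plugin is used: the last one registered, if any
        forward = self.plugins["dataforward"][-1] if self.plugins["dataforward"] else Trainer.default_forward
        for i, (x, y) in enumerate(self.train_data):
            for fn in self.plugins["batch"]:      # preprocess the batch
                x, y = fn(self, x, y)
            pred, loss, grads = forward(self, x, y)
            for fn in self.plugins["update"]:     # adjust gradients before the step
                grads = fn(self, grads)
            # plain SGD step, playing the role of optimizer.step()
            self.weights = [w - self.lr * g for w, g in zip(self.weights, grads)]
            for fn in self.plugins["iteration"]:  # per-iteration metrics and logging
                fn(self, epoch, i, pred, y, loss)
        for fn in self.plugins["epoch"]:          # end-of-epoch logging
            fn(self, epoch)

    def start(self, epochs):
        # The only method meant to be called from outside
        for epoch in range(epochs):
            self.train(epoch)


def lasso_update(strength):
    """'update' plugin: add the L1 subgradient strength * sign(w) to every gradient."""
    def plugin(trainer, grads):
        return [g + strength * ((w > 0) - (w < 0)) for g, w in zip(grads, trainer.weights)]
    return plugin


# Usage: fit y = 2 * x0 while the L1 penalty keeps the unused weight near zero
data = [([1.0, 0.5], 2.0), ([2.0, -1.0], 4.0), ([0.5, 1.0], 1.0)]
trainer = Trainer(weights=[0.0, 0.0], train_data=data)
epoch_losses = []


def log_epoch(t, epoch):
    # 'epoch' plugin: store the mean loss of the finished epoch, then reset
    t.history.append(sum(epoch_losses) / len(epoch_losses))
    epoch_losses.clear()


trainer.register_plugin("update", lasso_update(0.01))
trainer.register_plugin("iteration", lambda t, epoch, i, pred, y, loss: epoch_losses.append(loss))
trainer.register_plugin("epoch", log_epoch)
trainer.start(epochs=200)
print(trainer.weights, trainer.history[-1])
```

Because the trainer only walks the plugin lists at fixed points, new behaviour (a distillation forward, a different regularizer, another logger) is added by registering a function rather than by editing the training loop.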
plugins.py: you will basically understand what's inside once you take a look.
convert_DataParallel_Model.py: because PyTorch serializes a DataParallel-wrapped module differently from a plain module, saved models have to be converted between the two formats. This file provides a function that converts a DataParallel model into a normal one.
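The difference comes from nn.DataParallel storing the wrapped network as its `module` attribute, so every key in its state_dict gains a `module.` prefix. A minimal sketch of the conversion, assuming the checkpoint is a plain state_dict mapping names to tensors (the function name is hypothetical and this is not necessarily how convert_DataParallel_Model.py does it):

```python
from collections import OrderedDict


def convert_dataparallel_state_dict(state_dict, prefix="module."):
    """Return a copy of state_dict with the DataParallel 'module.' prefix stripped from its keys."""
    converted = OrderedDict()
    for key, value in state_dict.items():
        # Keys saved from nn.DataParallel look like 'module.conv1.weight'
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        converted[new_key] = value
    return converted


# Plain integers stand in for tensors here
saved = OrderedDict([("module.conv1.weight", 1), ("module.fc.bias", 2)])
print(list(convert_dataparallel_state_dict(saved)))  # ['conv1.weight', 'fc.bias']
```

With a real checkpoint this would be used as `model.load_state_dict(convert_dataparallel_state_dict(torch.load(path)))`. Saving `model.module.state_dict()` instead of `model.state_dict()` avoids the prefix in the first place.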
getClassesFromOfficialDataset.py: I was glad to see that the latest torchvision source adds class (label) names to official datasets such as CIFAR-10 and CIFAR-100, although that version has not been released yet. I therefore implemented a function that retrieves those label names following the latest source code of torchvision.datasets.cifar. It is only used when printing per-class accuracy (see the code in plugins.py).
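The label names live in a pickled meta file inside the extracted CIFAR archive: `batches.meta` (key `label_names`) in `cifar-10-batches-py`, and `meta` (key `fine_label_names`) in `cifar-100-python`, which is where torchvision's CIFAR classes read them from. A minimal sketch, assuming the standard archive layout under `root`; the function name is hypothetical:

```python
import os
import pickle

# (extracted folder, meta file, key holding the class names)
CIFAR_META = {
    "cifar10": ("cifar-10-batches-py", "batches.meta", "label_names"),
    "cifar100": ("cifar-100-python", "meta", "fine_label_names"),
}


def get_cifar_classes(root, dataset="cifar10"):
    """Read the class names of CIFAR-10/100 from the dataset's pickled meta file."""
    folder, filename, key = CIFAR_META[dataset]
    with open(os.path.join(root, folder, filename), "rb") as f:
        # The official files were pickled under Python 2, hence encoding="latin1"
        meta = pickle.load(f, encoding="latin1")
    return meta[key]
```

In current torchvision releases the same list is exposed as the dataset object's `classes` attribute, so this helper is mainly useful for older versions.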
measure.py contains two functions that measure the total number of parameters and the FLOPs of a given model.
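For a PyTorch model the parameter count is simply `sum(p.numel() for p in model.parameters())`; FLOPs have to be accumulated layer by layer. The sketch below shows the per-layer arithmetic for Conv2d and Linear layers, assuming dilation 1 and counting multiply-accumulates (MACs). Some tools report FLOPs as 2 × MACs, and which convention measure.py uses is not stated here; the helper names are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Conv2dSpec:
    in_channels: int
    out_channels: int
    kernel_size: int
    stride: int = 1
    padding: int = 0
    groups: int = 1
    bias: bool = True


def conv2d_cost(spec, in_h, in_w):
    """Return (params, MACs, out_h, out_w) for one Conv2d layer with dilation 1."""
    k = spec.kernel_size
    out_h = (in_h + 2 * spec.padding - k) // spec.stride + 1
    out_w = (in_w + 2 * spec.padding - k) // spec.stride + 1
    per_filter = (spec.in_channels // spec.groups) * k * k  # weights in one filter
    params = spec.out_channels * per_filter + (spec.out_channels if spec.bias else 0)
    macs = out_h * out_w * spec.out_channels * per_filter  # one MAC per weight per output pixel
    return params, macs, out_h, out_w


def linear_cost(in_features, out_features, bias=True):
    """Return (params, MACs) for one fully connected layer."""
    params = in_features * out_features + (out_features if bias else 0)
    return params, in_features * out_features


# First 3x3 conv of a CIFAR-sized VGG: 3 -> 64 channels on a 32x32 input
print(conv2d_cost(Conv2dSpec(3, 64, 3, padding=1), 32, 32))  # (1792, 1769472, 32, 32)
print(linear_cost(512, 10))                                  # (5130, 5120)
```

Summing these values over all layers, while tracking how the spatial size shrinks, gives the model totals.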
nets/ is a folder containing our own model definitions. We will provide more details in the README soon.
prune.py is the Python code for pruning plain CNNs such as VGG.
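The README does not say which pruning criterion prune.py uses. As an illustration only, the sketch below uses a swapped-in, commonly used criterion: ranking filters by their L1 norm (as in Li et al., "Pruning Filters for Efficient ConvNets") and keeping the largest ones. It also shows why pruning a plain CNN touches two layers: removing output filters of one conv removes the matching input channels of the next. Weights are nested lists shaped [out_channels][in_channels][kh][kw] so the example runs without PyTorch; all function names are hypothetical.

```python
def filter_l1_norms(weight):
    """weight is a nested list shaped [out_channels][in_channels][kh][kw]."""
    return [sum(abs(v) for channel in f for row in channel for v in row) for f in weight]


def filters_to_keep(weight, keep_ratio):
    """Indices (in original order) of the filters with the largest L1 norms."""
    norms = filter_l1_norms(weight)
    n_keep = max(1, int(len(norms) * keep_ratio))
    ranked = sorted(range(len(norms)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:n_keep])


def prune_conv_pair(weight, next_weight, keep):
    """Drop pruned filters from one conv and the matching input channels of the next conv."""
    pruned = [weight[i] for i in keep]
    next_pruned = [[f[i] for i in keep] for f in next_weight]
    return pruned, next_pruned


# Three 1x1 filters over one input channel, followed by a conv with two filters
conv1 = [[[[0.1]]], [[[-2.0]]], [[[0.5]]]]
conv2 = [[[[1.0]], [[2.0]], [[3.0]]], [[[4.0]], [[5.0]], [[6.0]]]]
keep = filters_to_keep(conv1, keep_ratio=0.7)  # norms 0.1, 2.0, 0.5 -> keep [1, 2]
conv1, conv2 = prune_conv_pair(conv1, conv2, keep)
print(keep, conv1, conv2)
```

With real tensors the same selection would be done with `torch.index_select` along dim 0 of the pruned conv and dim 1 of the next conv, and a BatchNorm layer between them would need the same channels kept.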
prune_resnet.py is the Python code for pruning non-plain CNNs such as ResNet. We will soon provide code for pruning DenseNet.
TODO:
- Document the utils module in more detail and add more, and more consistent, comments to the .py files (coming as soon as I have time)
- Implement a theoretical quantization scheme and then turn it into a practical one
- Make the codebase follow established design patterns more closely
- To be continued