Pocket is a fairly lightweight library built on the popular PyTorch framework. The library provides utilities aimed at lowering the barrier to entry for training deep neural networks. For most deep learning applications, the relevant code can be divided into three categories: model, data loader and training script. Existing frameworks already provide ample resources for popular models and datasets, yet lack highly encapsulated and flexible training utilities. Pocket is designed to fill this gap.
Pocket provides a range of engine classes that can perform training and testing with a minimal amount of code.
Pocket provides two base engine classes, pocket.core.LearningEngine and pocket.core.DistributedLearningEngine, with the following features:
- CPU/GPU training
- Multi-GPU (distributed) training
- Automatic checkpoint saving
- Detailed training logs
To accommodate distinct training scenarios, the learning engines are designed for maximum flexibility and follow the structure below:
self._on_start()                      # Invoked prior to all epochs
for ...:                              # Iterating over all epochs
    self._on_start_epoch()            # Invoked prior to each epoch
    for ...:                          # Iterating over all mini-batches
        self._on_start_iteration()    # Invoked prior to each iteration
        self._on_each_iteration()     # Forward pass, backward pass, etc.
        self._on_end_iteration()      # Invoked after each iteration
    self._on_end_epoch()              # Invoked after each epoch
self._on_end()                        # Invoked after all epochs
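The skeleton above is pseudocode. The following minimal, runnable sketch in plain Python (no PyTorch required) reproduces the same call order so it is easy to see which hook fires when. MiniEngine, TraceEngine, the _state dictionary and the num_epochs call argument are hypothetical names chosen for illustration and are not Pocket's actual API; in a real engine, _on_each_iteration would run the forward and backward passes.

```python
class MiniEngine:
    """Hypothetical, stripped-down engine that mirrors the documented hook order.

    This is NOT Pocket's implementation; it only reproduces the call
    sequence of the _on_* hooks.
    """

    def __init__(self, data_loader):
        # Any iterable of mini-batches works here (a list, a DataLoader, ...).
        self._data_loader = data_loader
        # Shared state the hooks can read; the name and keys are illustrative.
        self._state = {"epoch": 0, "iteration": 0, "batch": None}

    def __call__(self, num_epochs):
        self._on_start()
        for epoch in range(num_epochs):
            self._state["epoch"] = epoch
            self._on_start_epoch()
            for batch in self._data_loader:
                # The iteration counter keeps growing across epochs.
                self._state["iteration"] += 1
                self._state["batch"] = batch
                self._on_start_iteration()
                self._on_each_iteration()
                self._on_end_iteration()
            self._on_end_epoch()
        self._on_end()

    # Default hooks do nothing; subclasses override only the ones they need.
    def _on_start(self):
        pass

    def _on_start_epoch(self):
        pass

    def _on_start_iteration(self):
        pass

    def _on_each_iteration(self):
        pass

    def _on_end_iteration(self):
        pass

    def _on_end_epoch(self):
        pass

    def _on_end(self):
        pass


class TraceEngine(MiniEngine):
    """Records selected hook calls so their order can be inspected."""

    def __init__(self, data_loader):
        super().__init__(data_loader)
        self.trace = []

    def _on_start(self):
        self.trace.append("start")

    def _on_start_epoch(self):
        self.trace.append("start_epoch {}".format(self._state["epoch"]))

    def _on_each_iteration(self):
        self.trace.append("iteration {}: batch {}".format(
            self._state["iteration"], self._state["batch"]))

    def _on_end_epoch(self):
        self.trace.append("end_epoch {}".format(self._state["epoch"]))

    def _on_end(self):
        self.trace.append("end")


engine = TraceEngine(data_loader=["a", "b"])
engine(2)  # two epochs over two mini-batches
print("\n".join(engine.trace))
```

Because every hook defaults to a no-op, a subclass only overrides the stages it cares about, for example writing a checkpoint in _on_end_epoch or printing a log line in _on_end_iteration.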
For details on the base learning engines and the classes that inherit from them, refer to the documentation. For practical examples, refer to the following:
| Engine | Examples |
|:-------|:---------|
| pocket.core.MultiClassClassificationEngine | mnist |
| pocket.core.MultiLabelClassificationEngine | voc2012, hico |
| pocket.core.DistributedLearningEngine | mnist |
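The two classification engines in the table target different tasks. The plain-Python sketch below, which is independent of Pocket, illustrates the distinction at prediction time: in multi-class classification (e.g. MNIST digits) each sample has exactly one label, whereas in multi-label classification (e.g. VOC2012 or HICO, where an image can contain several objects or interactions) each class is scored independently. The function names and the 0.5 threshold are illustrative assumptions; the sketch does not reproduce how the engines compute their losses or metrics.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of class scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predict_multiclass(logits):
    """Exactly one class per sample: the index with the highest probability.

    Softmax is monotonic, so this equals the argmax of the raw logits.
    """
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)


def predict_multilabel(logits, threshold=0.5):
    """Any subset of classes: each class passes or fails its own threshold."""
    return [i for i, x in enumerate(logits) if sigmoid(x) >= threshold]


logits = [2.0, -1.0, 0.5]
print(predict_multiclass(logits))  # -> 0
print(predict_multilabel(logits))  # -> [0, 2]
```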
Anaconda/Miniconda is recommended for environment management. Follow the steps below to install the library.
# Create a conda environment (Python>=3.5)
conda create --name pocket python=3.8
conda activate pocket
# Install PyTorch (>=1.5.1)
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
# Install the remaining dependencies
pip install matplotlib tqdm scipy
# Clone and install Pocket in a directory of your choice
git clone https://github.com/fredzzhang/pocket.git
pip install -e pocket
# Run an example as a test (optional)
cd pocket/examples
CUDA_VISIBLE_DEVICES=0 python mnist.py
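After installation, a short Python check such as the one below can confirm that the dependencies from the steps above are importable. It is a hedged sketch and not part of Pocket: the package list mirrors the installation commands, and pocket is assumed to be the import name of the editable install. torch.cuda.is_available() is only queried if PyTorch is found.

```python
import importlib.util

# Package names taken from the installation steps; "pocket" is assumed
# to be the import name of the editable install.
REQUIRED = ("torch", "torchvision", "matplotlib", "tqdm", "scipy", "pocket")


def check_environment(packages=REQUIRED):
    """Return {package: True/False} depending on whether it can be found.

    find_spec locates a top-level package without importing it.
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    status = check_environment()
    for name, found in status.items():
        print("{:<12} {}".format(name, "OK" if found else "MISSING"))
    if status["torch"]:
        import torch  # only imported when PyTorch is installed
        print("CUDA available:", torch.cuda.is_available())
```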