Pocket

A Deep Learning Library to Enable Rapid Prototyping

Introduction

Pocket is a lightweight library built on the popular PyTorch framework. It provides utilities that lower the barrier to entry for training deep neural networks. For most deep learning applications, the relevant code can be divided into three categories: model, data loader and training script. Existing frameworks already provide ample resources for popular models and datasets, but lack training utilities that are both highly encapsulated and flexible. Pocket is designed to fill this gap.

Key Features

Pocket provides a range of engine classes that perform training and testing with a minimal amount of code. The following is a simple demo.

[Figure: demo]

Pocket provides two base classes of engines: pocket.core.LearningEngine and pocket.core.DistributedLearningEngine with the following features:

  • CPU/GPU training
  • Multi-GPU (distributed) training
  • Automatic checkpoint saving
  • Elaborate training log

To accommodate distinct training scenarios, the learning engines are implemented with maximum flexibility and follow the structure below:

self._on_start()                        # Invoked prior to all epochs
for ...                                 # Iterating over all epochs
    self._on_start_epoch()              # Invoked prior to each epoch
    for ...                             # Iterating over all mini-batches
        self._on_start_iteration()      # Invoked prior to each iteration
        self._on_each_iteration()       # Forward, backward pass etc.
        self._on_end_iteration()        # Invoked after each iteration
    self._on_end_epoch()                # Invoked after each epoch
self._on_end()                          # Invoked after all epochs
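The hook order above can be illustrated with a self-contained, hypothetical mini engine. It is not Pocket's LearningEngine: the class and attribute names (MiniEngine, MeanLossEngine, step_fn, log) are assumptions made for this sketch, and plain floats stand in for mini-batches so it runs without PyTorch. What it shows is the usage pattern the structure enables: subclass the engine and override only the hooks a scenario needs.

```python
from typing import Callable, Iterable, List, Optional


class MiniEngine:
    """Hypothetical engine reproducing the hook order above (not Pocket's code)."""

    def __init__(self, train_loader: Iterable, step_fn: Callable[[object], float]):
        self.train_loader = train_loader  # any iterable of mini-batches
        self.step_fn = step_fn            # stands in for forward/backward/update
        self.log: List[str] = []
        self.epoch = 0
        self.iteration = 0
        self.batch: Optional[object] = None
        self.loss: Optional[float] = None

    def __call__(self, num_epochs: int) -> None:
        # Calling the engine runs the full schedule (an assumption of this sketch)
        self._on_start()
        for _ in range(num_epochs):
            self._on_start_epoch()
            for batch in self.train_loader:
                self.batch = batch
                self._on_start_iteration()
                self._on_each_iteration()
                self._on_end_iteration()
            self._on_end_epoch()
        self._on_end()

    # Default hooks; subclasses override only what they need
    def _on_start(self) -> None:
        self.log.append("start")

    def _on_start_epoch(self) -> None:
        self.epoch += 1

    def _on_start_iteration(self) -> None:
        self.iteration += 1

    def _on_each_iteration(self) -> None:
        self.loss = self.step_fn(self.batch)

    def _on_end_iteration(self) -> None:
        pass

    def _on_end_epoch(self) -> None:
        self.log.append(f"epoch {self.epoch} done")

    def _on_end(self) -> None:
        self.log.append("end")


class MeanLossEngine(MiniEngine):
    """Customisation example: report the mean loss of each epoch."""

    def _on_start_epoch(self) -> None:
        super()._on_start_epoch()
        self._epoch_losses: List[float] = []  # reset at the start of every epoch

    def _on_end_iteration(self) -> None:
        self._epoch_losses.append(self.loss)

    def _on_end_epoch(self) -> None:
        mean = sum(self._epoch_losses) / len(self._epoch_losses)
        self.log.append(f"epoch {self.epoch} mean loss {mean:.2f}")


# Floats act as mini-batches; the "loss" is simply half of each batch value
engine = MeanLossEngine(train_loader=[1.0, 2.0, 3.0], step_fn=lambda batch: batch * 0.5)
engine(2)
print(engine.log)
# ['start', 'epoch 1 mean loss 1.00', 'epoch 2 mean loss 1.00', 'end']
```

Because every stage is a separate method, a scenario such as per-epoch loss reporting touches only the hooks it cares about, and the loop itself never changes.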

For details on the base learning engines and the classes that inherit from them, refer to the documentation. For practical examples, refer to the following:

  • pocket.core.MultiClassClassificationEngine: mnist
  • pocket.core.MultiLabelClassificationEngine: voc2012, hico
  • pocket.core.DistributedLearningEngine: mnist

Installation

Anaconda/miniconda is recommended for environment management. Follow the steps below to install the library.

# Create conda environment (python>=3.5)
conda create --name pocket python=3.8
conda activate pocket
# Install pytorch>=1.5.1
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
pip install matplotlib tqdm scipy
# Install Pocket under any desired directory
git clone https://github.com/fredzzhang/pocket.git
pip install -e pocket
# Run an example as a test (optional)
cd pocket/examples
CUDA_VISIBLE_DEVICES=0 python mnist.py
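After installation, the minimum versions stated in the steps above (python>=3.5, pytorch>=1.5.1) can also be checked with a short script. This is a hypothetical helper, not part of Pocket; it compares only the numeric release components and ignores pre-release and build tags, which is a simplifying assumption. It only checks that the pocket package can be imported, since it does not assume Pocket exposes a version number.

```python
import importlib
import sys
from typing import Dict, Optional, Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Numeric release components, e.g. '1.10.0+cu102' -> (1, 10, 0)."""
    release = version.split("+")[0]  # drop local build tags such as +cu102
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:             # keep leading digits only ('0a0' -> '0')
            if not ch.isdigit():
                break
            digits += ch
        if not digits:               # stop at tags like 'dev20230101'
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    """True if installed >= required, padding so '1.5' compares equal to '1.5.0'."""
    a, b = version_tuple(installed), version_tuple(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def module_version(name: str) -> Optional[str]:
    """Version string of an importable module, or None if it cannot be imported."""
    try:
        module = importlib.import_module(name)
    except (ImportError, OSError):   # OSError covers broken native libraries
        return None
    return str(getattr(module, "__version__", "unknown"))


def check_environment() -> Dict[str, bool]:
    python = ".".join(str(v) for v in sys.version_info[:3])
    torch_version = module_version("torch")
    return {
        "python>=3.5": meets_minimum(python, "3.5"),
        "pytorch>=1.5.1": torch_version is not None and meets_minimum(torch_version, "1.5.1"),
        "pocket importable": module_version("pocket") is not None,
    }


if __name__ == "__main__":
    print(check_environment())
```

Run it inside the activated pocket environment; in a working setup every value in the printed dictionary should be True.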

License

BSD-3-Clause License