/pytorch-trainer

PyTorch's Trainer like Chainer's Trainer

Primary LanguagePythonMIT LicenseMIT

PyTorch's Trainer like Chainer's Trainer

We can use Trainer, Evaluator, Extension, and Reporter on PyTorch.

Install

pip install git+https://github.com/Hiroshiba/pytorch-trainer

Example

Please see train_mnist.py that is modifyed from Chainer's train_mnisy.py.

# Train with Trainer
PYTHONPATH='.' python examples/train_mnist.py \
  --device cuda \
  --autoload \
  --epoch 5

The logs from LogReport extension:

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
0                       2.30575                              0.0768                    1.31841
1           0.213044    0.104855              0.935383       0.9668                    14.1628
2           0.0811537   0.0846022             0.97475        0.9737                    27.0918
3           0.0535006   0.0839404             0.982833       0.9749                    39.4642
4           0.0395484   0.0851855             0.987083       0.9763                    51.5603
5           0.0285093   0.0926847             0.990883       0.9726                    63.1763

Trainer can be saved everything, ex. Model, Optimizer, Iterator, Reporter, etc. So using trainer.load_state_dict, we can resume training!

# Resume with Trainer
PYTHONPATH='.' python examples/train_mnist.py \
  --device cuda \
  --autoload \
  --epoch 10  # start from 5 epoch to 10 epoch
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
0                       2.30575                              0.0768                    1.31841
1           0.213044    0.104855              0.935383       0.9668                    14.1628
2           0.0811537   0.0846022             0.97475        0.9737                    27.0918
3           0.0535006   0.0839404             0.982833       0.9749                    39.4642
4           0.0395484   0.0851855             0.987083       0.9763                    51.5603
5           0.0285093   0.0926847             0.990883       0.9726                    63.1763      <-- saved logs 
6           0.0257853   0.0679396             0.991567       0.9824                    75.748       <-- new logs
7           0.0202872   0.0777352             0.993583       0.9797                    85.5424
8           0.0214834   0.0869443             0.9928         0.9778                    94.8266
9           0.0166034   0.0999511             0.99435        0.9782                    104.533
10          0.0145655   0.0913273             0.995433       0.9801                    114.099

Difference from Chainer.Trainer

  • some extensions don't exist (ex. Schedulers or DumpGraph)
  • Trainer have Modules (because PyTorch's Optimizer don't contain Network)
  • serialize is replaced to state_dict and load_state_dict

Difference from Vanilla PyTorch

  • cannot use DataLoader (because Trainer needs Iterator, and Iterator needs the length of dataset)

Supported Classes

pytorch_trainer/
├── iterators
│   ├── multiprocess_iterator.py
│   ├── order_samplers.py
│   └── serial_iterator.py
├── reporter.py
└── training
    ├── extensions
    │   ├── evaluator.py
    │   ├── fail_on_nonnumber.py
    │   ├── log_report.py
    │   ├── micro_average.py
    │   ├── plot_report.py
    │   ├── print_report.py
    │   ├── progress_bar.py
    │   ├── snapshot_writers.py
    │   └── value_observation.py
    ├── trainer.py
    ├── triggers
    │   ├── early_stopping_trigger.py
    │   ├── interval_trigger.py
    │   ├── manual_schedule_trigger.py
    │   ├── minmax_value_trigger.py
    │   ├── once_trigger.py
    │   └── time_trigger.py
    └── updaters
        └── standard_updater.py

Test

pytest -s -v tests

TODO

  • Scheduler
  • DataLoader

License

MIT LICENSE (like Chainer)