We can use Trainer, Evaluator, Extension, and Reporter on PyTorch.
pip install git+https://github.com/Hiroshiba/pytorch-trainer
Please see train_mnist.py that is modifyed from Chainer's train_mnisy.py.
# Train with Trainer
PYTHONPATH='.' python examples/train_mnist.py \
--device cuda \
--autoload \
--epoch 5
The logs from LogReport extension:
epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time
0 2.30575 0.0768 1.31841
1 0.213044 0.104855 0.935383 0.9668 14.1628
2 0.0811537 0.0846022 0.97475 0.9737 27.0918
3 0.0535006 0.0839404 0.982833 0.9749 39.4642
4 0.0395484 0.0851855 0.987083 0.9763 51.5603
5 0.0285093 0.0926847 0.990883 0.9726 63.1763
Trainer can be saved everything, ex. Model, Optimizer, Iterator, Reporter, etc.
So using trainer.load_state_dict
, we can resume training!
# Resume with Trainer
PYTHONPATH='.' python examples/train_mnist.py \
--device cuda \
--autoload \
--epoch 10 # start from 5 epoch to 10 epoch
epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time
0 2.30575 0.0768 1.31841
1 0.213044 0.104855 0.935383 0.9668 14.1628
2 0.0811537 0.0846022 0.97475 0.9737 27.0918
3 0.0535006 0.0839404 0.982833 0.9749 39.4642
4 0.0395484 0.0851855 0.987083 0.9763 51.5603
5 0.0285093 0.0926847 0.990883 0.9726 63.1763 <-- saved logs
6 0.0257853 0.0679396 0.991567 0.9824 75.748 <-- new logs
7 0.0202872 0.0777352 0.993583 0.9797 85.5424
8 0.0214834 0.0869443 0.9928 0.9778 94.8266
9 0.0166034 0.0999511 0.99435 0.9782 104.533
10 0.0145655 0.0913273 0.995433 0.9801 114.099
- some extensions don't exist (ex. Schedulers or DumpGraph)
- Trainer have Modules (because PyTorch's Optimizer don't contain Network)
serialize
is replaced tostate_dict
andload_state_dict
- cannot use DataLoader (because Trainer needs Iterator, and Iterator needs the length of dataset)
pytorch_trainer/
├── iterators
│ ├── multiprocess_iterator.py
│ ├── order_samplers.py
│ └── serial_iterator.py
├── reporter.py
└── training
├── extensions
│ ├── evaluator.py
│ ├── fail_on_nonnumber.py
│ ├── log_report.py
│ ├── micro_average.py
│ ├── plot_report.py
│ ├── print_report.py
│ ├── progress_bar.py
│ ├── snapshot_writers.py
│ └── value_observation.py
├── trainer.py
├── triggers
│ ├── early_stopping_trigger.py
│ ├── interval_trigger.py
│ ├── manual_schedule_trigger.py
│ ├── minmax_value_trigger.py
│ ├── once_trigger.py
│ └── time_trigger.py
└── updaters
└── standard_updater.py
pytest -s -v tests
- Scheduler
- DataLoader
MIT LICENSE (like Chainer)