🐸 - A general purpose model trainer, as flexible as it gets

👟 Trainer

An opinionated, general-purpose model trainer built on PyTorch, with a simple code base.

Installation

From GitHub:

git clone https://github.com/coqui-ai/Trainer
cd Trainer
make install

From PyPI:

pip install trainer

Prefer installing from GitHub, as it is more stable.

Implementing a model

Subclass TrainerModel and override its methods.
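The subclass-and-override pattern can be sketched without the library installed. In the sketch below, SketchTrainerModel is a hypothetical stand-in, not the library's TrainerModel, and the method names (train_step, eval_step) and their return shape are assumptions made for illustration; check the real base class for the methods it expects. The toy model uses plain Python instead of PyTorch to stay self-contained.

```python
from abc import ABC, abstractmethod


class SketchTrainerModel(ABC):
    """Hypothetical stand-in for a trainer base class (not the real TrainerModel)."""

    @abstractmethod
    def train_step(self, batch, criterion):
        """Assumed contract: return (outputs, loss_dict) for one training batch."""

    @abstractmethod
    def eval_step(self, batch, criterion):
        """Assumed contract: return (outputs, loss_dict) for one evaluation batch."""


class LinearModel(SketchTrainerModel):
    """Toy model y = w * x, fitted by plain gradient descent."""

    def __init__(self, lr=0.1):
        self.w = 0.0
        self.lr = lr

    def forward(self, xs):
        return [self.w * x for x in xs]

    def train_step(self, batch, criterion):
        xs, ys = batch
        preds = self.forward(xs)
        loss = criterion(preds, ys)
        # Hand-written gradient of mean squared error with respect to w,
        # so this sketch only makes sense with an MSE criterion.
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
        self.w -= self.lr * grad
        return {"model_outputs": preds}, {"loss": loss}

    def eval_step(self, batch, criterion):
        xs, ys = batch
        preds = self.forward(xs)
        return {"model_outputs": preds}, {"loss": criterion(preds, ys)}


def mse(preds, targets):
    """Mean squared error over two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


model = LinearModel()
batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])  # targets follow y = 2x
for _ in range(50):
    _, losses = model.train_step(batch, mse)
```

A training loop only needs to call the overridden methods, which is what lets one trainer drive many different models.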

Training a model

See the test script in the repository, which trains a basic MNIST model.

Training with DDP

$ python -m trainer.distribute --script path/to/your/train.py --gpus "0,1"

We don't use .spawn() to start multi-GPU training because it imposes several limitations:

  • Everything must be picklable.
  • .spawn() trains the model in subprocesses, so the model in the main process is not updated.
  • A DataLoader with N worker processes gets very slow when N is large.
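The first limitation is easy to run into in practice. A minimal sketch using only the standard library: is_picklable is a hypothetical helper (not part of this package) that checks whether an object could be sent to a spawned subprocess.

```python
import pickle


def is_picklable(obj):
    """Return True if obj can be serialized with pickle.dumps."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# Plain data such as a config dictionary pickles fine.
config_ok = is_picklable({"lr": 0.1, "epochs": 10})

# A lambda cannot be pickled, so it could not be handed to a spawned process.
lambda_ok = is_picklable(lambda x: 2 * x)
```

Lambdas and locally defined functions (for example, an inline collate function) are common offenders; launching each process from a script avoids the need to serialize them.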

Profiling example

  • Create the torch profiler as you like and pass it to the trainer.
    import torch
    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )
    prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
  • Run TensorBoard.
    tensorboard --logdir="./profiler/"
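To see which steps the schedule above actually records, the phases can be simulated in plain Python. This is a sketch that mirrors the documented behavior of torch.profiler.schedule with no skipped initial steps; schedule_actions is a hypothetical helper, not a torch function, and the action names are plain strings rather than torch's enum.

```python
def schedule_actions(num_steps, wait, warmup, active, repeat):
    """List the profiler action taken at each step of a wait/warmup/active cycle."""
    cycle = wait + warmup + active
    actions = []
    for step in range(num_steps):
        # After `repeat` full cycles the profiler stays idle (repeat=0 means no limit).
        if repeat > 0 and step // cycle >= repeat:
            actions.append("none")
            continue
        pos = step % cycle
        if pos < wait:
            actions.append("none")
        elif pos < wait + warmup:
            actions.append("warmup")
        elif pos < cycle - 1:
            actions.append("record")
        else:
            # The last active step of a cycle also triggers on_trace_ready.
            actions.append("record_and_save")
    return actions


actions = schedule_actions(num_steps=12, wait=1, warmup=1, active=3, repeat=2)
for step, action in enumerate(actions):
    print(step, action)
```

With wait=1, warmup=1, active=3, repeat=2, each cycle spans five steps, so two traces are written over the first ten steps and nothing is recorded afterwards.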

Supported Experiment Loggers

To add a new logger, subclass BaseDashboardLogger and override its methods.