
A GPipe implementation in PyTorch

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

torchgpipe



How to use

Prerequisites are:

  • Python 3.6+
  • PyTorch 1.0+
  • Your nn.Sequential module

Install via PyPI:

$ pip install torchgpipe

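Wrap your nn.Sequential module with torchgpipe.GPipe. You must specify balance, a list giving the number of consecutive layers in each partition; its entries must add up to the number of layers in the module. You can also set chunks, the number of micro-batches each input mini-batch is split into: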

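from torch import nn
from torchgpipe import GPipe

# a, b, c and d stand for your own layers. balance=[1, 1, 1, 1] puts
# one layer in each of four partitions, each on its own device.
model = nn.Sequential(a, b, c, d)
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)

for input in data_loader:  # data_loader: your own iterable of mini-batches
    # Each mini-batch is split into 8 micro-batches that flow through the pipeline.
    output = model(input)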
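To make balance and chunks more concrete, here is a minimal, framework-free sketch in plain Python. The helpers split_by_balance, split_into_micro_batches and clock_cycles are hypothetical and are not part of torchgpipe's API; they only illustrate the idea. balance groups consecutive layers into partitions, chunks splits each mini-batch into micro-batches, and in a GPipe-style schedule partition j processes micro-batch i at clock cycle i + j, so different partitions work on different micro-batches at the same time. torchgpipe's internals may differ in detail.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_by_balance(layers: Sequence[T], balance: Sequence[int]) -> List[List[T]]:
    """Hypothetical helper: group consecutive layers; balance[j] layers go to partition j."""
    if any(size <= 0 for size in balance):
        raise ValueError("every partition must hold at least one layer")
    if sum(balance) != len(layers):
        raise ValueError("sum of balance must equal the number of layers")
    partitions, start = [], 0
    for size in balance:
        partitions.append(list(layers[start:start + size]))
        start += size
    return partitions


def split_into_micro_batches(batch: Sequence[T], chunks: int) -> List[List[T]]:
    """Hypothetical helper: split a mini-batch into at most `chunks` pieces.

    Each piece has ceil(len(batch) / chunks) samples (the last may be smaller),
    which is the same rule torch.chunk uses, so fewer than `chunks` pieces
    can come back.
    """
    if not batch:
        return []
    size = -(-len(batch) // chunks)  # ceiling division
    return [list(batch[i:i + size]) for i in range(0, len(batch), size)]


def clock_cycles(m: int, n: int) -> List[List[Tuple[int, int]]]:
    """Hypothetical helper: forward schedule for m micro-batches and n partitions.

    At clock k, partition j runs micro-batch i = k - j, giving m + n - 1 cycles.
    Each task is a (micro_batch, partition) pair.
    """
    schedule = []
    for k in range(m + n - 1):
        schedule.append([(k - j, j) for j in range(max(0, k - m + 1), min(k + 1, n))])
    return schedule


if __name__ == "__main__":
    print(split_by_balance(["a", "b", "c", "d"], [1, 1, 1, 1]))
    print(split_into_micro_batches(list(range(10)), 8))
    for k, tasks in enumerate(clock_cycles(3, 2)):
        print(k, tasks)
```

With balance=[1, 1, 1, 1] and chunks=8, clock_cycles(8, 4) yields 11 clock cycles, compared with 32 if the four partitions processed the eight micro-batches one at a time without overlap. Note also that split_into_micro_batches(list(range(10)), 8) returns only 5 micro-batches of 2 samples each, so chunks acts as an upper bound in this sketch.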

This project is still under development. Any public API may change without deprecation warnings until v0.1.0.