Implementations of data-parallel training in PyTorch. The repository includes examples of how several distributed operations can be used in PyTorch, a dataloader that loads data and scatters minibatches across nodes, and distributed training over multiple nodes/devices with an option to shard optimizer state across the nodes. Because each node then stores only its share of the optimizer state, sharding frees memory for a larger batch size, at the cost of additional communication between nodes. I verified the implementation by varying the batch size and checking that, at some batch sizes, training runs out of CUDA memory (OOM) unless optimizer state is sharded.
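To make the sharding idea concrete, here is a minimal single-process sketch of data-parallel SGD with sharded optimizer state (in the spirit of ZeRO stage 1). It is not the repository's code. The linear model, MSE loss, SGD-with-momentum optimizer, and every helper name (`local_grad`, `shard_range`, `reduce_scatter_mean`, `all_gather`, `train_sharded`, `train_unsharded`) are assumptions made for illustration, and the collectives are simulated with plain Python lists rather than `torch.distributed` calls. Each simulated rank computes a gradient on its slice of the minibatch, receives only its shard of the averaged gradient, updates its own parameter shard using its own momentum shard, and then all ranks all-gather the updated parameters.

```python
# Sketch only: simulated ranks, hypothetical helpers, plain lists instead of tensors.
import random
from typing import List, Tuple

Vector = List[float]


def local_grad(w: Vector, xs: List[Vector], ys: Vector) -> Vector:
    """Mean-squared-error gradient for y ~ w.x over one rank's slice of the batch."""
    g = [0.0] * len(w)
    for x, y in zip(xs, ys):
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for i, xi in enumerate(x):
            g[i] += 2.0 * err * xi / len(xs)
    return g


def shard_range(n: int, world_size: int, rank: int) -> range:
    """Hypothetical helper: contiguous parameter indices owned by `rank`."""
    per = (n + world_size - 1) // world_size
    return range(rank * per, min((rank + 1) * per, n))


def reduce_scatter_mean(grads: List[Vector], world_size: int) -> List[Vector]:
    """Simulated reduce-scatter: rank r receives only its shard of the mean gradient."""
    n = len(grads[0])
    return [[sum(g[i] for g in grads) / world_size for i in shard_range(n, world_size, r)]
            for r in range(world_size)]


def all_gather(shards: List[Vector]) -> Vector:
    """Simulated all-gather: every rank ends up with the concatenated parameters."""
    return [v for shard in shards for v in shard]


def train_sharded(w0: Vector, xs: List[Vector], ys: Vector, world_size: int,
                  steps: int, lr: float = 0.05, beta: float = 0.9) -> Tuple[Vector, List[Vector]]:
    n = len(w0)
    w = list(w0)  # full parameters are replicated on every rank
    # Optimizer state is sharded: rank r keeps momentum only for the parameters it owns.
    momentum = [[0.0] * len(shard_range(n, world_size, r)) for r in range(world_size)]
    chunk = len(xs) // world_size  # assumes the batch splits evenly across ranks
    for _ in range(steps):
        # Scatter the minibatch: rank r computes a gradient on rows [r*chunk, (r+1)*chunk).
        grads = [local_grad(w, xs[r * chunk:(r + 1) * chunk], ys[r * chunk:(r + 1) * chunk])
                 for r in range(world_size)]
        grad_shards = reduce_scatter_mean(grads, world_size)
        new_shards = []
        for r in range(world_size):
            m = momentum[r]
            for j, g in enumerate(grad_shards[r]):
                m[j] = beta * m[j] + g  # SGD-with-momentum update, local to this shard
            idx = shard_range(n, world_size, r)
            new_shards.append([w[i] - lr * m[j] for j, i in enumerate(idx)])
        w = all_gather(new_shards)
    return w, momentum


def train_unsharded(w0: Vector, xs: List[Vector], ys: Vector, steps: int,
                    lr: float = 0.05, beta: float = 0.9) -> Vector:
    """Reference: the same optimizer on one device with the full batch and full state."""
    w, m = list(w0), [0.0] * len(w0)
    for _ in range(steps):
        g = local_grad(w, xs, ys)
        m = [beta * mi + gi for mi, gi in zip(m, g)]
        w = [wi - lr * mi for wi, mi in zip(w, m)]
    return w


rng = random.Random(0)
true_w = [0.5, -1.0, 2.0, 0.0, 1.5]
xs = [[rng.uniform(-1, 1) for _ in true_w] for _ in range(8)]
ys = [sum(a * b for a, b in zip(true_w, x)) for x in xs]
w_sharded, state = train_sharded([0.0] * 5, xs, ys, world_size=2, steps=20)
w_ref = train_unsharded([0.0] * 5, xs, ys, steps=20)
print(max(abs(a - b) for a, b in zip(w_sharded, w_ref)))  # floating-point noise only
print([len(s) for s in state])  # per-rank momentum sizes: [3, 2]
```

With equal-sized slices, the mean of the per-rank gradients equals the full-batch gradient, so the sharded run tracks the unsharded reference while each rank holds only about 1/world_size of the momentum. In a real PyTorch run, the simulated collectives would correspond to `torch.distributed.reduce_scatter` and `torch.distributed.all_gather`.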
yulonglin/data_paralellism
Implementation of data-parallel training with sharded optimizer state, written with Nikola Jurkovic
Language: Python · License: MIT