Some helper PyTorch modules that allow complex networks to be expressed as Sequentials

Primary language: Python. License: MIT.

pytorch-sequential-helpers

Some helper modules that allow complex networks (particularly those with parallel data flows) to be expressed as a single nn.Sequential.

Example

Pass the same input to two different NN branches and merge the results by adding them:

nn.Sequential(
  Parallel(branch1, branch2),
  Add()
)
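
To make the data flow concrete, here is a minimal sketch of how modules like Parallel and Add could behave. These class bodies are assumptions written for illustration, not the library's actual implementation. The branch definitions (nn.Linear layers) and tensor shapes are likewise hypothetical. The sketch relies on the fact that nn.Sequential passes each module's return value, here a tuple of tensors, as the single argument to the next module.

```python
import torch
import torch.nn as nn


class Parallel(nn.Module):
    """Hypothetical sketch: feed the same input to every branch."""

    def __init__(self, *branches):
        super().__init__()
        # ModuleList registers the branches so their parameters are tracked.
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        # One output per branch, returned as a tuple.
        return tuple(branch(x) for branch in self.branches)


class Add(nn.Module):
    """Hypothetical sketch: element-wise sum of same-shaped tensors."""

    def forward(self, xs):
        return torch.stack(tuple(xs), dim=0).sum(dim=0)


# Illustrative branches; any modules with matching output shapes would do.
branch1 = nn.Linear(8, 4)
branch2 = nn.Sequential(nn.Linear(8, 4), nn.ReLU())

model = nn.Sequential(
    Parallel(branch1, branch2),
    Add(),
)

x = torch.randn(2, 8)
y = model(x)  # same result as branch1(x) + branch2(x)
```

Under these assumptions, the only contract between the two modules is "a tuple of tensors goes in, one tensor comes out", which is what lets the whole network stay a flat nn.Sequential.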

Split the RGB channels of a batch of images, pass each channel to a different NN branch, and concatenate the results along the channel dimension:

nn.Sequential(
  Split((1, 1, 1), dim=1),
  Parallel(branch1, branch2, branch3),
  Concat(dim=1)
)
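
The split-and-concatenate pipeline can be sketched the same way. Again, the class bodies below are assumptions for illustration rather than the library's real code. In particular, this sketch assumes that Parallel routes each element of a tuple input to the branch at the same position (and broadcasts a single tensor to all branches). The Conv2d branches and the 16x16 image size are hypothetical.

```python
import torch
import torch.nn as nn


class Split(nn.Module):
    """Hypothetical sketch: cut a tensor into chunks of the given sizes."""

    def __init__(self, sizes, dim=0):
        super().__init__()
        self.sizes = list(sizes)
        self.dim = dim

    def forward(self, x):
        # torch.split returns a tuple of tensors, one per requested size.
        return torch.split(x, self.sizes, dim=self.dim)


class Parallel(nn.Module):
    """Hypothetical sketch: tuple input goes one element per branch."""

    def __init__(self, *branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        if isinstance(x, (tuple, list)):
            if len(x) != len(self.branches):
                raise ValueError("number of inputs must match number of branches")
            return tuple(b(xi) for b, xi in zip(self.branches, x))
        # A single tensor is passed unchanged to every branch.
        return tuple(b(x) for b in self.branches)


class Concat(nn.Module):
    """Hypothetical sketch: concatenate a tuple of tensors along one dim."""

    def __init__(self, dim=0):
        super().__init__()
        self.dim = dim

    def forward(self, xs):
        return torch.cat(tuple(xs), dim=self.dim)


# Illustrative branches: each maps one input channel to four feature maps.
branches = [nn.Conv2d(1, 4, kernel_size=3, padding=1) for _ in range(3)]

model = nn.Sequential(
    Split((1, 1, 1), dim=1),
    Parallel(*branches),
    Concat(dim=1),
)

images = torch.randn(2, 3, 16, 16)  # batch of 2 RGB images
out = model(images)                 # shape (2, 12, 16, 16)
```

With dim=1 used for both Split and Concat, the channel axis is taken apart and reassembled, so the batch and spatial dimensions pass through unchanged.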