Easy generative modeling in PyTorch.

pytorch-generative

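pytorch-generative is a Python library which makes generative modeling in PyTorch easier by providing reference implementations of generative models (see Supported Algorithms), lower level building blocks in pytorch_generative.nn for writing models from scratch, and a training script for reproducing results.

To get started, see the sections below.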

Installation

To install pytorch-generative, clone the repository and install the requirements:

git clone https://www.github.com/EugenHotaj/pytorch-generative
cd pytorch-generative
pip install -r requirements.txt

After installation, run the tests to sanity check that everything works:

python -m unittest discover

Reproducing Results

All our models implement a reproduce function with all the hyperparameters necessary to reproduce the results listed in the supported algorithms section. This makes it very easy to reproduce any results using our training script, for example:

python train.py --model image_gpt --use-cuda

To run the model on a different dataset or with different hyperparameters, modify its reproduce function and rerun the command above.
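As a rough illustration of how a training script can route `--model` to a per-model `reproduce` function, consider the sketch below. The registry, the function name `reproduce_image_gpt`, its arguments, and the default values are all hypothetical placeholders chosen for this example. They are not the library's actual code or hyperparameters.

```python
import argparse


def reproduce_image_gpt(n_epochs=1, batch_size=64, use_cuda=False):
    """Hypothetical reproduce function; the defaults are placeholders."""
    device = "cuda" if use_cuda else "cpu"
    # A real reproduce function would build the data loaders and the model,
    # then train. Here we only return the resolved configuration.
    return {
        "model": "image_gpt",
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "device": device,
    }


# Hypothetical mapping from the --model flag to a reproduce function.
MODEL_REGISTRY = {"image_gpt": reproduce_image_gpt}


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", choices=sorted(MODEL_REGISTRY), required=True)
    parser.add_argument("--use-cuda", action="store_true")
    args = parser.parse_args(argv)
    return MODEL_REGISTRY[args.model](use_cuda=args.use_cuda)
```

With this layout, changing the dataset or a hyperparameter means editing one function, and the command line stays the same.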

Example - ImageGPT

Supported models are implemented as PyTorch Modules and are easy to use:

from pytorch_generative import models

... # Data loading code.

model = models.ImageGPT(in_channels=1, out_channels=1, in_size=28)
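for x, _ in train_loader:  # Assumes the loader yields (image, label) batches.
  logits = model(x)  # Per-pixel logits with the same spatial size as x.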

Alternatively, lower level building blocks in pytorch_generative.nn can be used to write models from scratch. We show how to implement a convolutional ImageGPT model below:

import torch
from torch import nn

from pytorch_generative import nn as pg_nn


class TransformerBlock(nn.Module):
  """An ImageGPT Transformer block."""

  def __init__(self, 
               n_channels, 
               n_attention_heads):
    """Initializes a new TransformerBlock instance.
    
    Args:
      n_channels: The number of input and output channels.
      n_attention_heads: The number of attention heads to use.
    """
    super().__init__()
    self._ln1 = pg_nn.NCHWLayerNorm(n_channels)
    self._ln2 = pg_nn.NCHWLayerNorm(n_channels)
    self._attn = pg_nn.CausalAttention(
        in_channels=n_channels,
        embed_channels=n_channels,
        out_channels=n_channels,
        n_heads=n_attention_heads,
        mask_center=False)
    self._out = nn.Sequential(
        nn.Conv2d(
            in_channels=n_channels, 
            out_channels=4*n_channels, 
            kernel_size=1),
        nn.GELU(),
        nn.Conv2d(
            in_channels=4*n_channels, 
            out_channels=n_channels, 
            kernel_size=1))

  def forward(self, x):
    x = x + self._attn(self._ln1(x))
    return x + self._out(self._ln2(x))


class ImageGPT(nn.Module):
  """The ImageGPT Model."""
  
  def __init__(self,
               in_channels,
               out_channels,
               in_size,
               n_transformer_blocks=8,
               n_attention_heads=4,
               n_embedding_channels=16):
    """Initializes a new ImageGPT instance.
    
    Args:
      in_channels: The number of input channels.
      out_channels: The number of output channels.
      in_size: Size of the input images. Used to create positional encodings.
      n_transformer_blocks: Number of TransformerBlocks to use.
      n_attention_heads: Number of attention heads to use.
      n_embedding_channels: Number of attention embedding channels to use.
    """
    super().__init__()
    self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
    self._input = pg_nn.CausalConv2d(
        mask_center=True,
        in_channels=in_channels,
        out_channels=n_embedding_channels,
        kernel_size=3,
        padding=1)
    self._transformer = nn.Sequential(
        *[TransformerBlock(n_channels=n_embedding_channels,
                           n_attention_heads=n_attention_heads)
          for _ in range(n_transformer_blocks)])
    self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
    self._out = nn.Conv2d(in_channels=n_embedding_channels,
                          out_channels=out_channels,
                          kernel_size=1)

  def forward(self, x):
    x = self._input(x + self._pos)
    x = self._transformer(x)
    x = self._ln(x)
    return self._out(x)
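The `mask_center` argument used above is not explained in this README. Assuming it follows the PixelCNN masking convention (each output pixel may only see pixels above it and to its left, and `mask_center=True` additionally hides the pixel itself), a causal convolution can be sketched in plain PyTorch as below. `MaskedConv2d` is a hypothetical stand-in written for this example, not the implementation of `pg_nn.CausalConv2d`.

```python
import torch
from torch import nn
import torch.nn.functional as F


class MaskedConv2d(nn.Conv2d):
  """Hypothetical stand-in illustrating a causal (masked) 2D convolution."""

  def __init__(self, mask_center, *args, **kwargs):
    super().__init__(*args, **kwargs)
    kh, kw = self.kernel_size
    mask = torch.zeros(kh, kw)
    # Rows strictly above the center row are fully visible.
    mask[:kh // 2, :] = 1.0
    # In the center row, only pixels left of the center are visible.
    mask[kh // 2, :kw // 2] = 1.0
    # Expose the center pixel itself unless it should be masked.
    if not mask_center:
      mask[kh // 2, kw // 2] = 1.0
    # A buffer moves with the module across devices but is not trained.
    self.register_buffer("mask", mask[None, None])

  def forward(self, x):
    # Zero out the weights that would look at "future" pixels.
    return F.conv2d(x, self.weight * self.mask, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
```

Under this assumption, the first layer masks the center so the prediction for a pixel never depends on that pixel's own value, while later layers can leave the center unmasked because their inputs are already causal.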

Supported Algorithms

pytorch-generative supports the following algorithms.

We train likelihood-based models on dynamically binarized MNIST and report the negative log likelihood in nats (lower is better) in the tables below.
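To make the metric concrete, here is a minimal sketch of how such a number can be computed. It assumes the usual meaning of dynamic binarization (each pixel is resampled as a Bernoulli draw with probability equal to its intensity every time the image is loaded) and a per-image score summed over all pixels. The helper names are hypothetical and this is not the library's evaluation code.

```python
import math
import random


def dynamically_binarize(pixels, rng):
  """Samples each pixel as Bernoulli(intensity); intensities lie in [0, 1]."""
  return [1 if rng.random() < p else 0 for p in pixels]


def bernoulli_nll_nats(targets, probs, eps=1e-7):
  """Negative Bernoulli log likelihood in nats, summed over pixels."""
  total = 0.0
  for t, p in zip(targets, probs):
    p = min(max(p, eps), 1.0 - eps)  # Clamp to avoid log(0).
    total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
  return total


def nats_to_bits_per_dim(nats, n_dims=28 * 28):
  """Converts a per-image score in nats to bits per pixel."""
  return nats / (n_dims * math.log(2))
```

For example, a score of 79.17 nats on 28x28 images corresponds to roughly 0.146 bits per pixel.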

Autoregressive Models

Algorithm        Binarized MNIST (nats)   Links
PixelSNAIL       78.61                    Code, Paper
ImageGPT         79.17                    Code, Paper
Gated PixelCNN   81.50                    Code, Paper
PixelCNN         81.45                    Code, Paper
MADE             84.87                    Code, Paper
NADE             85.65                    Code, Paper

Variational Autoencoders

NOTE: The results below are the (variational) lower bound on the log likelihood.

Algorithm   Binarized MNIST (nats)   Links
VD-VAE      <= 80.72                 Code, Paper
VAE         <= 86.77                 Code, Paper
VQ-VAE      N/A                      Code, Paper
VQ-VAE-2    N/A                      Code, Paper
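The "<=" signs follow from the standard evidence lower bound (ELBO). The notation below is the usual one for VAEs and is an assumption, since this README does not define symbols: $q_\phi$ is the approximate posterior, $p_\theta$ the decoder, and $p(z)$ the prior.

```latex
\log p_\theta(x)
  \;\ge\;
  \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right]
  - \mathrm{KL}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right)
  \;=\; \mathrm{ELBO}(x)
\qquad\Longrightarrow\qquad
-\log p_\theta(x) \;\le\; -\mathrm{ELBO}(x)
```

Read this way, an entry such as "<= 86.77" says that the model's negative log likelihood is at most 86.77 nats. The true value may be lower, so these rows are not directly comparable to the exact likelihoods of the autoregressive models.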

Neural Style Transfer

Blog: https://towardsdatascience.com/how-to-get-beautiful-results-with-neural-style-transfer-75d0c05d6489
Notebook: https://github.com/EugenHotaj/pytorch-generative/blob/master/notebooks/style_transfer.ipynb
Paper: https://arxiv.org/abs/1508.06576

Compositional Pattern Producing Networks

Notebook: https://github.com/EugenHotaj/pytorch-generative/blob/master/notebooks/cppn.ipynb
Background: https://en.wikipedia.org/wiki/Compositional_pattern-producing_network