pixelcnn-pytorch

A collection of PixelCNN models implemented in PyTorch

Inspiration

Based on Pixel Recurrent Neural Networks by van den Oord et al. This repository does not attempt to reproduce the original results in the paper.

PixelCNN

PixelCNNs are autoregressive generative models that treat the generation of an image as the generation of a sequence of pixels. More formally, a PixelCNN models the joint distribution of the pixels of an n×n image x as the following product of conditional distributions, where xi is a single pixel:

$$p(x) = \prod_{i=1}^{n^2} p(x_i \mid x_1, \ldots, x_{i-1})$$

The ordering of the pixel dependencies is in raster scan order: row by row and pixel by pixel within every row. Every pixel therefore depends on all the pixels above and to the left of it, and not on any other pixels. We see this setup in other autoregressive models such as MADE. The difference lies in the way the conditional distributions are constructed. With PixelCNN every conditional distribution is modelled by a CNN using masked convolutions.
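The raster scan ordering also dictates how sampling works: pixels are drawn one at a time, row by row, each conditioned on the pixels already drawn. The following is a minimal sketch for binary images, not the repository's own sampling code. It assumes a hypothetical `predict_probs` callable that maps a batch of partially generated images to per-pixel Bernoulli probabilities; `toy_predict_probs` is a hand-written stand-in for a trained model.

```python
import torch


def sample_raster_scan(predict_probs, n_samples, height, width):
    """Draw binary images pixel by pixel in raster scan order.

    `predict_probs` is a hypothetical callable: it takes a float tensor of
    shape (n, 1, H, W) and returns P(x_i = 1 | x_<i) for every position,
    in the same shape. Any autoregressive model could play this role.
    """
    images = torch.zeros(n_samples, 1, height, width)
    with torch.no_grad():
        for row in range(height):
            for col in range(width):
                probs = predict_probs(images)
                # Only the current position is sampled; later positions
                # are still zero and are filled in by later iterations.
                images[:, :, row, col] = torch.bernoulli(probs[:, :, row, col])
    return images


def toy_predict_probs(images):
    """Toy stand-in for a trained model: each pixel is the opposite of its
    left neighbour, and the first pixel of every row is 1."""
    left = torch.zeros_like(images)
    left[..., 1:] = images[..., :-1]
    return 1.0 - left


samples = sample_raster_scan(toy_predict_probs, n_samples=2, height=3, width=4)
```

Note that each sample requires one forward pass per pixel, so sampling is much slower than evaluating the likelihood of a given image, which needs a single pass.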

The left figure visualizes how the PixelCNN maps a neighborhood of pixels to a prediction for the next pixel. To generate pixel xi, the model can only condition on the previously generated pixels x1, ..., xi-1. This conditioning is done by masking the convolutional filters, as shown in the right figure. The mask shown is a type A mask, in contrast to a type B mask, where the weight for the middle pixel is also set to 1.
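The two mask types can be sketched as a thin wrapper around `nn.Conv2d`. This is an illustrative version for single-channel inputs, not necessarily how this repository implements it; the class name `MaskedConv2d` and its argument order are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConv2d(nn.Conv2d):
    """Conv2d whose kernel is zeroed at every position that comes after the
    centre in raster scan order, and at the centre itself for type A."""

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if mask_type not in ("A", "B"):
            raise ValueError("mask_type must be 'A' or 'B'")
        _, _, k_h, k_w = self.weight.shape
        mask = torch.ones(k_h, k_w)
        # In the centre row, hide the centre pixel onwards (type A) or
        # everything to the right of the centre pixel (type B).
        first_hidden = k_w // 2 + (1 if mask_type == "B" else 0)
        mask[k_h // 2, first_hidden:] = 0
        # Hide all rows below the centre row.
        mask[k_h // 2 + 1:] = 0
        # A buffer moves with the module across devices but is not trained.
        self.register_buffer("mask", mask[None, None])

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


conv_a = MaskedConv2d("A", 1, 1, kernel_size=3, padding=1)
conv_b = MaskedConv2d("B", 1, 1, kernel_size=3, padding=1)
```

Multiplying the weights by the mask in `forward` keeps the hidden weights from influencing the output even though the optimizer may still update them.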

PixelCNN models

Regular PixelCNN

This model follows a simple PixelCNN architecture to model binary MNIST and shapes images. It has the following architecture:

  • A 7×7 masked type A convolution
  • Five 7×7 masked type B convolutions
  • Two 1×1 masked type B convolutions
  • Appropriate ReLU nonlinearities and Batch Normalization in-between
  • 64 convolutional filters
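The layer list above could be assembled roughly as follows. This is a sketch under assumptions: the names `MaskedConv2d` and `PixelCNN`, the exact placement of Batch Normalization and ReLU, and the single-logit output of the last 1×1 convolution are illustrative choices and may differ from the repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConv2d(nn.Conv2d):
    """Illustrative masked convolution (type A hides the centre pixel)."""

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        _, _, k_h, k_w = self.weight.shape
        mask = torch.ones(k_h, k_w)
        mask[k_h // 2, k_w // 2 + (1 if mask_type == "B" else 0):] = 0
        mask[k_h // 2 + 1:] = 0
        self.register_buffer("mask", mask[None, None])

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


class PixelCNN(nn.Module):
    """One type A layer, five 7x7 type B layers, two 1x1 type B layers."""

    def __init__(self, in_channels=1, n_filters=64):
        super().__init__()
        layers = [MaskedConv2d("A", in_channels, n_filters, 7, padding=3)]
        for _ in range(5):
            layers += [nn.BatchNorm2d(n_filters), nn.ReLU(),
                       MaskedConv2d("B", n_filters, n_filters, 7, padding=3)]
        layers += [nn.BatchNorm2d(n_filters), nn.ReLU(),
                   MaskedConv2d("B", n_filters, n_filters, 1)]
        # Assumption: the last layer emits one Bernoulli logit per pixel.
        layers += [nn.BatchNorm2d(n_filters), nn.ReLU(),
                   MaskedConv2d("B", n_filters, in_channels, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # logits, same spatial size as the input


model = PixelCNN().eval()
logits = model(torch.rand(2, 1, 8, 8))
```

For binary images, such logits would typically be trained with `F.binary_cross_entropy_with_logits` against the input image itself. The example runs in `eval()` mode because Batch Normalization in training mode computes statistics over all pixels of the batch.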

PixelCNN with independent color channels (PixelRCNN)

This model supports RGB color channels, but models the color channels independently. More formally, we model the following parameterized distribution:

$$p_\theta(x) = \prod_{i} \prod_{c \in \{R, G, B\}} p_\theta(x_i^c \mid x_{<i})$$

The model is trained on Colored Shapes and Colored MNIST. It uses the following architecture:

  • A 7×7 masked type A convolution
  • 8 residual blocks with masked type B convolutions
  • Appropriate ReLU nonlinearities and Batch Normalization in-between
  • 128 convolutional filters
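The internal layout of the residual blocks is not spelled out above. One common choice, used here purely as an assumption, follows the residual block of the Pixel Recurrent Neural Networks paper: a 1×1 convolution that halves the number of filters, a 3×3 masked type B convolution, and a 1×1 convolution that restores the width, wrapped in a skip connection. The names `MaskedConv2d` and `ResidualBlock` are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConv2d(nn.Conv2d):
    """Illustrative masked convolution (type A hides the centre pixel)."""

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        _, _, k_h, k_w = self.weight.shape
        mask = torch.ones(k_h, k_w)
        mask[k_h // 2, k_w // 2 + (1 if mask_type == "B" else 0):] = 0
        mask[k_h // 2 + 1:] = 0
        self.register_buffer("mask", mask[None, None])

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


class ResidualBlock(nn.Module):
    """x + f(x), where f only uses type B masked convolutions."""

    def __init__(self, n_filters=128):
        super().__init__()
        half = n_filters // 2
        self.block = nn.Sequential(
            nn.ReLU(),
            MaskedConv2d("B", n_filters, half, 1),
            nn.ReLU(),
            MaskedConv2d("B", half, half, 3, padding=1),
            nn.ReLU(),
            MaskedConv2d("B", half, n_filters, 1),
        )

    def forward(self, x):
        # The skip connection passes the features of the current position
        # through unchanged, which is why only type B masks belong here.
        return x + self.block(x)


block = ResidualBlock(128)
features = block(torch.randn(1, 128, 8, 8))
```

Because the block preserves the tensor shape, eight of them can be stacked directly after the initial type A convolution.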

PixelCNN with dependent color channels (Autoregressive PixelRCNN)

This PixelCNN models dependent color channels. This is done by changing the masking scheme for the center pixel. The filters are split into 3 groups, and each group may only see the groups before it (or, for type B masks, the groups before it and itself), which maintains the autoregressive property. More formally, we model the parameterized distribution:

$$p_\theta(x) = \prod_{i} p_\theta(x_i^R \mid x_{<i}) \, p_\theta(x_i^G \mid x_i^R, x_{<i}) \, p_\theta(x_i^B \mid x_i^R, x_i^G, x_{<i})$$

To predict pixel xi in channel R, we use only the context, i.e. the previous pixels x<i. To predict xi in channel G, we use the same context, and since the R value of xi is already available at that point, we may use it as well. Similarly, to predict xi in channel B, we use the context along with the R and G values of xi. A type A mask enforces exactly this ordering: at the center pixel, each channel group sees only the groups before it. A type B mask additionally lets each group see itself at the center pixel. This way, the prediction for each color channel depends on the channels before it.

The figure above shows the difference between a type A and a type B mask. The 'context' refers to all the previous pixels (x<i).
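A channel-aware mask of this kind can be built as a 4D tensor with the same shape as the convolution weight. The sketch below assumes that the channels of each layer are laid out contiguously by group (all R filters, then all G, then all B); the function name `color_mask` and this layout are illustrative assumptions, and the repository may order channels differently.

```python
import torch


def color_mask(mask_type, out_channels, in_channels, kernel_size, n_groups=3):
    """Mask of shape (out_channels, in_channels, k, k) for a masked
    convolution whose channels are split into R, G and B groups."""
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    if out_channels % n_groups or in_channels % n_groups:
        raise ValueError("channel counts must be divisible by n_groups")
    k, c = kernel_size, kernel_size // 2
    mask = torch.zeros(out_channels, in_channels, k, k)
    # Context: rows above the centre and pixels left of the centre are
    # visible to every group, in every input channel.
    mask[:, :, :c] = 1
    mask[:, :, c, :c] = 1
    # Centre pixel: output group g sees input group h only if
    # h < g (type A) or h <= g (type B).
    out_size, in_size = out_channels // n_groups, in_channels // n_groups
    for g in range(n_groups):
        n_visible = g + 1 if mask_type == "B" else g
        mask[g * out_size:(g + 1) * out_size, :n_visible * in_size, c, c] = 1
    return mask


mask_a = color_mask("A", 3, 3, kernel_size=3)
mask_b = color_mask("B", 3, 3, kernel_size=3)
# Centre-pixel connectivity, rows = output (R, G, B), columns = input (R, G, B).
print(mask_a[:, :, 1, 1])
print(mask_b[:, :, 1, 1])
```

At the center pixel, the type A mask is strictly lower triangular over the channel groups and the type B mask is lower triangular including the diagonal.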

Conditional PixelCNNs

This PixelCNN is class-conditional on Binary MNIST and Binary Shapes. Formally, we model the conditional distribution:

$$p_\theta(x \mid y) = \prod_{i} p_\theta(x_i \mid x_{<i}, y)$$

Class labels are conditioned on by adding a conditional bias in each convolutional layer. More precisely, in the $\ell$-th convolutional layer, we compute

$$W_\ell * x + b_\ell + V_\ell y,$$

where $W_\ell * x + b_\ell$ is a masked convolution (as in the previous models), $V_\ell$ is a 2D weight matrix, and $y$ is a one-hot encoding of the class label (the conditional bias is broadcast spatially and added channel-wise). The model uses an architecture similar to the regular PixelCNN.
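A layer with such a conditional bias might look like the sketch below. The class name `ConditionalMaskedConv2d`, its constructor arguments, and the use of a bias-free `nn.Linear` to hold the 2D weight matrix are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalMaskedConv2d(nn.Module):
    """Masked convolution plus a class-dependent bias per output channel."""

    def __init__(self, mask_type, n_classes, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=kernel_size // 2)
        k, c = kernel_size, kernel_size // 2
        mask = torch.ones(k, k)
        mask[c, c + (1 if mask_type == "B" else 0):] = 0
        mask[c + 1:] = 0
        self.register_buffer("mask", mask[None, None])
        # The 2D weight matrix: one bias vector per class, without its own
        # offset because the convolution already has a bias term.
        self.class_bias = nn.Linear(n_classes, out_channels, bias=False)

    def forward(self, x, y_onehot):
        h = F.conv2d(x, self.conv.weight * self.mask, self.conv.bias,
                     padding=self.conv.padding)
        # Shape (batch, out_channels) -> (batch, out_channels, 1, 1) so the
        # bias broadcasts over all spatial positions.
        return h + self.class_bias(y_onehot)[:, :, None, None]


layer = ConditionalMaskedConv2d("B", n_classes=10, in_channels=1,
                                out_channels=4, kernel_size=3)
x = torch.rand(2, 1, 6, 6)
y = F.one_hot(torch.tensor([3, 7]), num_classes=10).float()
out = layer(x, y)
```

Since the label enters only through a spatially constant bias, it does not interact with the mask and the autoregressive ordering over pixels is unaffected.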

Datasets

The four datasets used:

  • Binary Shapes
  • Binary MNIST
  • Colored Shapes
  • Colored MNIST

Generated samples from the models

Below are samples generated by the different PixelCNN models after training.

PixelCNN

Binary Shapes Binary MNIST

PixelRCNN

Colored Shapes Colored MNIST

Autoregressive PixelRCNN

Colored Shapes Colored MNIST

Conditional PixelCNN

Binary Shapes Binary MNIST