
PyTorch implementation of SalGAN

Primary language: Python. License: MIT.

SalGAN

This is a clean PyTorch implementation of the paper SalGAN: Visual Saliency Prediction with Generative Adversarial Networks.

The source code of SalGAN is publicly available [1], but it is written in Theano. There is also a PyTorch implementation of SalGAN [2], but several people who tried to run it have reported that its adversarial loss function does not match the one in the original paper. In this implementation, the adversarial loss function is the same as the one stated in the original SalGAN paper.
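As we read the SalGAN paper, the discriminator is trained with binary cross-entropy (BCE) to output 1 for (image, ground-truth map) pairs and 0 for (image, generated map) pairs, while the generator minimizes a pixel-wise BCE content loss weighted by a hyperparameter alpha, plus an adversarial BCE term that rewards fooling the discriminator. The sketch below writes these losses in plain Python over flat lists of probabilities so it runs without PyTorch. The function names are hypothetical and are not this repository's API, and `alpha` is deliberately left as a required argument rather than guessing the paper's value.

```python
import math

EPS = 1e-7  # clamp probabilities away from 0 and 1 to avoid log(0)


def bce(pred, target):
    """Mean binary cross-entropy between two equal-length lists of probabilities."""
    assert pred and len(pred) == len(target)
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, EPS), 1.0 - EPS)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(pred)


def discriminator_loss(d_real, d_fake):
    """D should output 1 on (image, ground-truth map) and 0 on (image, generated map)."""
    return bce(d_real, [1.0] * len(d_real)) + bce(d_fake, [0.0] * len(d_fake))


def generator_loss(pred_map, gt_map, d_fake, alpha):
    """alpha-weighted pixel-wise BCE content loss plus the adversarial term BCE(D(I, S_hat), 1)."""
    content = bce(pred_map, gt_map)
    adversarial = bce(d_fake, [1.0] * len(d_fake))
    return alpha * content + adversarial
```

Note that the adversarial term is written as BCE against a target of 1 on the generated pair; if you compare against another implementation, this term is the first place to check.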

The generative adversarial network consists of two components:

  • an encoder-decoder generator that predicts saliency maps, and
  • a discriminator that distinguishes ground-truth saliency maps from generated ones.

The generator is a standard U-shaped encoder-decoder network: the encoder is identical to VGG-16, and the decoder is its reversed version.
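A minimal, framework-free sketch of this layout, assuming the standard VGG-16 configuration of 3×3 convolutions (padding 1) and 2×2 max pooling. Following this README's note that the final pooling and fully connected layers are dropped, the encoder keeps VGG-16's convolutional stack minus its last pool; the decoder mirrors it, with each pooling replaced by 2× upsampling (an assumption here, based on our reading of the SalGAN paper). The names `VGG16`, `encoder_cfg`, `decoder_cfg`, and `trace` are hypothetical, and the 192×256 input size is only an example.

```python
# VGG-16 convolutional configuration: integers are 3x3 conv output channels, "M" is 2x2 max pooling.
VGG16 = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
         512, 512, 512, "M", 512, 512, 512, "M"]


def encoder_cfg():
    # Drop the final pooling layer; the fully connected layers are not in this list at all.
    return VGG16[:-1]


def decoder_cfg():
    # Mirror the encoder: reverse layer order and replace each pooling with 2x upsampling ("U").
    return ["U" if layer == "M" else layer for layer in reversed(encoder_cfg())]


def trace(cfg, channels, height, width):
    """Return (channels, height, width) after cfg; padded 3x3 convs keep the spatial size."""
    for layer in cfg:
        if layer == "M":
            height, width = height // 2, width // 2
        elif layer == "U":
            height, width = height * 2, width * 2
        else:
            channels = layer
    return channels, height, width


enc_out = trace(encoder_cfg(), 3, 192, 256)   # (512, 12, 16): four poolings divide by 16
dec_out = trace(decoder_cfg(), *enc_out)      # (64, 192, 256): back to input resolution
```

In our reading of the paper, a final 1×1 convolution with a sigmoid maps the 64 decoder channels to a single-channel saliency map; it is left out of the sketch.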

The discriminator is a smaller convolutional network followed by three fully connected layers. The system architecture, the loss functions, and the optimizers all follow the SalGAN paper.
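The sketch below traces tensor shapes through a discriminator of this kind, assuming (as we understand the SalGAN paper) that it receives the RGB image concatenated with a saliency map as a 4-channel input. The layer widths in `DISC` (a 1×1 conv to 3 channels, then pairs of 3×3 convs with 32 and 64 channels, three poolings, and fully connected layers of 100, 2, and 1 units) are assumptions for illustration and should be checked against the paper's table; `DISC` and `fc_sizes` are hypothetical names.

```python
# Hypothetical discriminator layout: ("conv", out_channels), ("pool",), ("fc", out_features).
DISC = [("conv", 3), ("conv", 32), ("pool",),
        ("conv", 64), ("conv", 64), ("pool",),
        ("conv", 64), ("conv", 64), ("pool",),
        ("fc", 100), ("fc", 2), ("fc", 1)]


def fc_sizes(in_channels, height, width):
    """Trace DISC on an (in_channels, height, width) input; return (in, out) for each FC layer."""
    channels = in_channels  # only affects the first conv's weights, not any output shape
    features = None
    sizes = []
    for layer in DISC:
        if layer[0] == "conv":
            channels = layer[1]  # padded convs keep the spatial size
        elif layer[0] == "pool":
            height, width = height // 2, width // 2
        else:
            if features is None:
                features = channels * height * width  # flatten the conv output
            sizes.append((features, layer[1]))
            features = layer[1]
    return sizes


sizes = fc_sizes(4, 192, 256)  # image (3 channels) + saliency map (1 channel)
```

With three poolings, a 192×256 input shrinks to 24×32, so the first fully connected layer sees 64 × 24 × 32 = 49152 features; the last layer's single output, passed through a sigmoid, would be the probability that the map is ground truth.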

Note that the final pooling layer and the fully connected layers are removed from VGG-16.


If you run into any problems with the code, please let me know or open an issue.


[1] https://github.com/imatge-upc/salgan

[2] https://github.com/batsa003/salgan1