
AnoGAN-pytorch

PyTorch implementation of "Unsupervised Anomaly Detection with Generative Adversarial Networks to Guide Marker Discovery".
Official Paper: https://arxiv.org/pdf/1703.05921.pdf

What is AnoGAN?

AnoGAN is a deep convolutional generative adversarial network that learns a manifold of normal anatomical variability. It is combined with a novel anomaly scoring scheme based on a mapping from image space to latent space.
Given a query image x, the goal is to find a point z in the latent space whose generated image G(z) is visually most similar to x and lies on the manifold X.
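
The discrimination loss in this README unpacks "feature, _ = D(x)". This means the discriminator is expected to return its intermediate features together with its real/fake output. A minimal sketch of a DCGAN-style generator and discriminator that follows this convention is shown here. The 64×64 single-channel image size, the 100-dimensional latent vector, the layer widths, and the class names Generator and Discriminator are assumptions for illustration. None of them are taken from the repository.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100  # assumed latent size


class Generator(nn.Module):
    """DCGAN-style generator: z of shape (N, 100, 1, 1) -> image (N, 1, 64, 64) in [-1, 1]."""

    def __init__(self, latent_dim: int = LATENT_DIM, base: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, base * 8, 4, 1, 0, bias=False),  # 1 -> 4
            nn.BatchNorm2d(base * 8), nn.ReLU(True),
            nn.ConvTranspose2d(base * 8, base * 4, 4, 2, 1, bias=False),  # 4 -> 8
            nn.BatchNorm2d(base * 4), nn.ReLU(True),
            nn.ConvTranspose2d(base * 4, base * 2, 4, 2, 1, bias=False),  # 8 -> 16
            nn.BatchNorm2d(base * 2), nn.ReLU(True),
            nn.ConvTranspose2d(base * 2, base, 4, 2, 1, bias=False),  # 16 -> 32
            nn.BatchNorm2d(base), nn.ReLU(True),
            nn.ConvTranspose2d(base, 1, 4, 2, 1, bias=False),  # 32 -> 64
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """Returns (intermediate features, logit) so D can double as a feature extractor."""

    def __init__(self, base: int = 64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, base, 4, 2, 1, bias=False),  # 64 -> 32
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base, base * 2, 4, 2, 1, bias=False),  # 32 -> 16
            nn.BatchNorm2d(base * 2), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base * 2, base * 4, 4, 2, 1, bias=False),  # 16 -> 8
            nn.BatchNorm2d(base * 4), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base * 4, base * 8, 4, 2, 1, bias=False),  # 8 -> 4
            nn.BatchNorm2d(base * 8), nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Conv2d(base * 8, 1, 4, 1, 0, bias=False)  # 4 -> 1

    def forward(self, x):
        feat = self.features(x)
        logit = self.classifier(feat).view(-1)  # one real/fake logit per image
        return feat, logit


G, D = Generator().eval(), Discriminator().eval()
z = torch.randn(2, LATENT_DIM, 1, 1)
with torch.no_grad():
    img = G(z)
    feat, logit = D(img)
print(img.shape, feat.shape, logit.shape)
# torch.Size([2, 1, 64, 64]) torch.Size([2, 512, 4, 4]) torch.Size([2])
```

Which discriminator layer serves as the feature representation is a design choice. Tapping the last convolutional block, as done here, is only one option.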

Brief Summary of AnoGAN

  1. Train a DCGAN solely on images of healthy cases, with the aim of modeling the variety of healthy appearance.
  2. Map each new image to the latent space by iteratively updating z with backpropagation, keeping the trained G and D fixed, to find the most similar image G(z).
  3. Compute the anomaly score A(x), a weighted sum of the residual loss and the discrimination loss.
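
Steps 2 and 3 can be sketched as follows, assuming a discriminator that returns (features, output) as the loss code in this README does. Several pieces are assumptions for illustration:

- the toy fully connected G and D, which stand in for trained DCGAN networks
- the Adam optimizer and its learning rate
- the 300 update steps
- the helper names total_loss and map_to_latent

Only z is optimized, and the network weights stay frozen.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
LATENT_DIM, IMG_DIM, FEAT_DIM = 8, 64, 16  # toy sizes (assumed)


class ToyD(nn.Module):
    """Toy discriminator returning (features, logit), the convention the README's losses expect."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(IMG_DIM, FEAT_DIM), nn.LeakyReLU(0.2))
        self.classifier = nn.Linear(FEAT_DIM, 1)

    def forward(self, x):
        f = self.features(x)
        return f, self.classifier(f)


def total_loss(x, z, G, D, lam=0.1):
    """(1 - lam) * residual loss + lam * discrimination loss for latent z."""
    g_z = G(z)
    feat_x, _ = D(x)
    feat_g, _ = D(g_z)
    r_loss = torch.sum(torch.abs(x - g_z))
    d_loss = torch.sum(torch.abs(feat_x - feat_g))
    return (1 - lam) * r_loss + lam * d_loss


def map_to_latent(x, G, D, steps=300, lr=0.05, lam=0.1):
    """Step 2: iteratively update z by backpropagation; step 3: score at the final z."""
    for p in list(G.parameters()) + list(D.parameters()):
        p.requires_grad_(False)  # freeze the trained networks (mutates G and D in place)
    z = torch.randn(x.shape[0], LATENT_DIM, requires_grad=True)  # random starting point
    opt = torch.optim.Adam([z], lr=lr)  # only z is optimized
    history = []
    for _ in range(steps):
        opt.zero_grad()
        loss = total_loss(x, z, G, D, lam)
        loss.backward()
        opt.step()
        history.append(loss.item())
    with torch.no_grad():
        score = total_loss(x, z, G, D, lam).item()  # anomaly score A(x)
    return z.detach(), score, history


G = nn.Sequential(nn.Linear(LATENT_DIM, 32), nn.Tanh(), nn.Linear(32, IMG_DIM), nn.Tanh())
D = ToyD()

z_true = torch.randn(1, LATENT_DIM)
with torch.no_grad():
    x_normal = G(z_true)  # a query the generator can reproduce exactly

z_hat, score, history = map_to_latent(x_normal, G, D)
print(f"loss {history[0]:.3f} -> {history[-1]:.3f}, anomaly score {score:.3f}")
```

Because x_normal is produced by G itself, a good latent match exists and the loss should drop. An image that the generator cannot reproduce would be expected to keep a higher final loss, and therefore a higher anomaly score.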

Loss Functions of AnoGAN: Anomaly Score = (1 - lambda) * Residual Loss + lambda * Discrimination Loss

The total loss minimized when searching for the latent variable z is the same weighted sum of the residual loss and the discrimination loss. Its value at the final z is the anomaly score.
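
The same objective can be written out using the following notation, as we read it in the paper:

- z_γ is the latent point after γ update steps.
- Γ is the final step.
- f(·) is the output of an intermediate discriminator layer.

```latex
\begin{aligned}
\mathcal{L}_R(z_\gamma) &= \sum \left| x - G(z_\gamma) \right| \\
\mathcal{L}_D(z_\gamma) &= \sum \left| f(x) - f(G(z_\gamma)) \right| \\
\mathcal{L}(z_\gamma) &= (1 - \lambda)\,\mathcal{L}_R(z_\gamma) + \lambda\,\mathcal{L}_D(z_\gamma) \\
A(x) &= (1 - \lambda)\,R(x) + \lambda\,D(x), \quad R(x) = \mathcal{L}_R(z_\Gamma),\; D(x) = \mathcal{L}_D(z_\Gamma)
\end{aligned}
```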

  • Residual Loss: the L1 distance between the query image and the generated image in image space, which measures their visual similarity.

  • Discrimination Loss: the L1 distance between the query image and the generated image in the discriminator's feature representation, which measures their feature similarity. It pushes the generated image onto the learned manifold by using the trained discriminator as a feature extractor rather than as a classifier.
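
Here is a small worked example of the two losses on toy 2×2 tensors. A hypothetical row-sum function f stands in for the discriminator's intermediate layer. The example shows that the two terms can disagree. The pixels differ, so the residual loss is nonzero, while the toy features coincide, so the discrimination loss is zero.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])    # toy query image
G_z = torch.tensor([[2.0, 1.0], [3.0, 4.0]])  # toy generated image


def f(img):
    # Hypothetical feature extractor standing in for an intermediate
    # discriminator layer: the sum of each row.
    return img.sum(dim=-1)


residual = torch.sum(torch.abs(x - G_z))              # |1-2| + |2-1| + 0 + 0 = 2
discrimination = torch.sum(torch.abs(f(x) - f(G_z)))  # f(x) = f(G_z) = [3, 7] -> 0
print(residual.item(), discrimination.item())  # 2.0 0.0
```

With lambda = 0.1, this pair would give an anomaly score of 0.9 * 2 + 0.1 * 0 = 1.8.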

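import torch


def residual_loss(x, G_z):
    # Residual loss: L1 distance between the query image x and G(z) in image space
    return torch.sum(torch.abs(x - G_z))


def discrimination_loss(x, z, D, G):
    # D must return (intermediate features, output); only the features are compared,
    # so the trained discriminator serves as a feature extractor, not a classifier
    feature_G_z, _ = D(G(z))
    feature_x, _ = D(x)
    return torch.sum(torch.abs(feature_x - feature_G_z))


def anomaly_score(r_loss, d_loss, lambda_=0.1):
    # Weighted sum: (1 - lambda) * residual loss + lambda * discrimination loss
    return (1 - lambda_) * r_loss + lambda_ * d_loss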
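
The functions above reduce with torch.sum over the entire tensor. A batch of query images therefore yields a single combined value. If one score per image is wanted, a hypothetical per-sample variant can reduce over every dimension except the batch dimension. Its function names are assumed, and it takes precomputed discriminator features rather than D, G, and z:

```python
import torch


def residual_loss_per_sample(x, G_z):
    # Sum |x - G(z)| over all non-batch dimensions -> shape (N,)
    return torch.abs(x - G_z).flatten(start_dim=1).sum(dim=1)


def discrimination_loss_per_sample(feature_x, feature_G_z):
    # Same reduction applied to discriminator features -> shape (N,)
    return torch.abs(feature_x - feature_G_z).flatten(start_dim=1).sum(dim=1)


def anomaly_score_per_sample(r_loss, d_loss, lambda_=0.1):
    return (1 - lambda_) * r_loss + lambda_ * d_loss


torch.manual_seed(0)
x, G_z = torch.randn(4, 1, 8, 8), torch.randn(4, 1, 8, 8)  # random stand-in images
fx, fg = torch.randn(4, 16, 2, 2), torch.randn(4, 16, 2, 2)  # random stand-in features

scores = anomaly_score_per_sample(
    residual_loss_per_sample(x, G_z), discrimination_loss_per_sample(fx, fg)
)
print(scores.shape)  # torch.Size([4])
```

Summing these per-image scores recovers the single batch-level value that the original functions would return.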

References