Computer Science Communication

  • Student Name: Haozhen Shen
  • Student #: 1003115112

Sharing

I am comfortable with this post being shared through any source.

Image Super Resolution Using SRGAN

The problem

Super-resolution is the process of improving and refining the details within an image. Usually we take a low-resolution image as input and aim to output the same image at a higher resolution and with more refined detail. However, the refined details are generally unknown, which makes this an ill-posed problem: multiple high-resolution images can be valid solutions for a single low-resolution input. In this blog post we focus on single image super-resolution (SISR).

We state the problem more formally. Consider a degradation process $D(I, \delta)$, where $I$ is an image and $\delta$ the degradation parameters. Our low-resolution (LR) image $I_x$ is obtained by degrading a high-resolution (HR) image $I_y$,

$$I_{x} = D(I_y, \delta)$$

where the degradation process $D$ is generally unknown. We aim to find an SR model $F$ such that

$${\hat{I}}_y=F(I_x, \theta)$$

where $\theta$ denotes the parameters of our SR model $F$. As a result, the objective of image super-resolution becomes

$$\hat{\theta}=\underset{\theta}{\mathrm{argmin}}{L({\hat{I}}_y,\ I_y)+\lambda\phi(\theta)}$$

where $L(\hat{I}_y, I_y)$ is the loss between the generated HR image $\hat{I}_y$ and the ground-truth image $I_y$, $\phi(\theta)$ is a regularization term, and $\lambda$ is the trade-off parameter.

This is a complex image processing task, which requires us to reconstruct the corresponding high-resolution images from the observed low-resolution images. In this blog we will consider a generative approach based on deep learning to tackle this problem.
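
Before moving on, here is a minimal sketch of one common choice for the degradation $D$, bicubic downsampling, which is also what we use later in the experiments; the x4 scale factor here is an assumption for illustration.

import torch.nn.functional as F

def degrade_bicubic(hr, scale=4):
    """Toy degradation D(I_y, delta): bicubic downsampling by `scale`.

    hr: a batch of HR images with shape (N, C, H, W).
    Returns an LR batch with shape (N, C, H // scale, W // scale).
    """
    return F.interpolate(hr, scale_factor=1 / scale,
                         mode="bicubic", align_corners=False)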

GANs

We briefly review the concept of Generative Adversarial Networks (GANs). A GAN is a generative model whose architecture commonly incorporates deep neural networks. The GAN architecture involves two sub-models: a generator that produces new samples and a discriminator that classifies samples as real or fake. The two sub-models are trained adversarially, as if playing a zero-sum game.

[Figure: GAN architecture — generator and discriminator]

Samples produced by the generator, along with real samples, are fed to the discriminator to be classified as real or fake. The discriminator is updated to get better at telling real samples from fake ones, while the generator is updated to get better at fooling the discriminator.
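
A minimal sketch of this alternating update, assuming a latent-noise generator, a discriminator with sigmoid output, and a binary cross-entropy objective (all hypothetical placeholders), could look like:

import torch
import torch.nn as nn

def gan_step(generator, discriminator, opt_g, opt_d, real, latent_dim=100):
    """One alternating GAN update with a binary cross-entropy objective."""
    bce = nn.BCELoss()
    batch = real.size(0)

    # Discriminator step: push D(real) toward 1 and D(G(z)) toward 0.
    z = torch.randn(batch, latent_dim)
    fake = generator(z).detach()          # detach so only D is updated here
    loss_d = bce(discriminator(real), torch.ones(batch, 1)) + \
             bce(discriminator(fake), torch.zeros(batch, 1))
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # Generator step: push D(G(z)) toward 1, i.e. try to fool D.
    fake = generator(torch.randn(batch, latent_dim))
    loss_g = bce(discriminator(fake), torch.ones(batch, 1))
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()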

SRGAN

Our goal is to train a generating function $G$ that, for a given LR input image, estimates its corresponding HR counterpart. To achieve this, we train a generator network as a feed-forward CNN $G_{\theta_G}$, parametrized by $\theta_G$, which we obtain by solving

$$\hat{\theta}_G=\underset{\theta_G}{\mathrm{argmin}}\,{L(G_{\theta_G}(I_x),\ I_y)+\lambda\phi(\theta_G)}$$

In our GAN setting, the input to the generator is a low-resolution (LR) image, from which it generates a fake high-resolution image. The discriminator receives two sets of images: fake high-resolution images produced by the generator and the corresponding real high-resolution images. It learns to distinguish real from fake for each given image by outputting a value between 0 and 1, where values close to 1 indicate a real high-resolution image and values close to 0 indicate a generated one. The generator and discriminator are trained in an alternating manner by solving the adversarial min-max problem.
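
Written out, the adversarial min-max problem from Ledig et al. [1] is

$$\min_{\theta_G}\max_{\theta_D}\ \mathbb{E}_{I^{HR}\sim p_{\mathrm{train}}(I^{HR})}\left[\log D_{\theta_D}(I^{HR})\right]+\mathbb{E}_{I^{LR}\sim p_{G}(I^{LR})}\left[\log\left(1-D_{\theta_D}(G_{\theta_G}(I^{LR}))\right)\right]$$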

[Figure: SRGAN training setup — LR input to the generator, generated SR and real HR images to the discriminator]

Network Architecture

[Figures: SRGAN generator and discriminator network architectures]

Perceptual Loss Function

The perceptual loss $l^{SR}$ is defined as the weighted sum of a content loss $l^{SR}_X$ and an adversarial loss component:

$$l^{SR}=l^{SR}_X+10^{-3}\,l^{SR}_{Gen}$$

We describe the components of the content loss as follows. Note that instead of the widely used pixel-wise MSE loss,

$$l^{SR}_{MSE}=\frac{1}{r^{2}WH}\sum_{x=1}^{rW}\sum_{y=1}^{rH}\left(I^{HR}_{x,y}-G_{\theta_G}(I^{LR})_{x,y}\right)^{2}$$

where $r$ is the downsampling factor and $W$, $H$ are the dimensions of the LR image, we take a step further and make use of the VGG loss, defined on the ReLU activation layers of the pre-trained 19-layer VGG network described by Simonyan and Zisserman,

$$l^{SR}_{VGG/i,j}=\frac{1}{W_{i,j}H_{i,j}}\sum_{x=1}^{W_{i,j}}\sum_{y=1}^{H_{i,j}}\left(\phi_{i,j}(I^{HR})_{x,y}-\phi_{i,j}(G_{\theta_G}(I^{LR}))_{x,y}\right)^{2}$$

where $\phi_{i,j}$ denotes the feature map obtained by the $j$-th convolution (after activation) before the $i$-th max-pooling layer of the VGG19 network, and $W_{i,j}$ and $H_{i,j}$ are the dimensions of the respective feature maps. To encourage the generator to fool the discriminator, we add the generative component of our GAN to the perceptual loss. The generative loss $l^{SR}_{Gen}$ is defined based on the probabilities of the discriminator $D_{\theta_D}(G_{\theta_G}(I^{LR}))$ over all training samples,

$$l^{SR}_{Gen}=\sum_{n=1}^{N}-\log D_{\theta_D}(G_{\theta_G}(I^{LR}))$$
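
A minimal sketch of this combined loss, assuming a VGG-based feature extractor such as the FeatureExtractor module defined in the implementation below, and using a batch mean rather than a sum for the adversarial term, might look like:

import torch
import torch.nn as nn

def perceptual_loss(sr, hr, d_fake, feature_extractor, adv_weight=1e-3):
    """Perceptual loss: VGG content loss plus 1e-3 times the adversarial loss.

    sr, hr: generated and ground-truth HR image batches.
    d_fake: discriminator outputs D(G(I_LR)), probabilities in (0, 1).
    """
    mse = nn.MSELoss()
    content = mse(feature_extractor(sr), feature_extractor(hr))
    adversarial = (-torch.log(d_fake + 1e-8)).mean()  # -log D(G(I_LR))
    return content + adv_weight * adversarial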

A sample PyTorch implementation of the networks is given below.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # Truncated pre-trained VGG19 used to compute the VGG content loss.
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(
            *list(vgg19_model.features.children())[:18]
        )

    def forward(self, img):
        return self.feature_extractor(img)


class ResidualBlock(nn.Module):
    def __init__(self, input_dim):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(input_dim, 0.8),
            nn.PReLU(),
            nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(input_dim, 0.8),
        )

    def forward(self, x):
        # Skip connection
        return x + self.conv_block(x)

class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpsampleBlock, self).__init__()
        # Convolution followed by PixelShuffle(2): trades a 4x increase in
        # channels for a 2x increase in spatial resolution (sub-pixel conv).
        self.upsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )

    def forward(self, img):
        return self.upsample(img)
        
class Generator(nn.Module):
    def __init__(self, upsample_factor):
        n_residual_blocks = 16
        super(Generator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = upsample_factor

        self.conv1 = nn.Conv2d(3, 64, 9, stride=1, padding=4)
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)
        
        self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        up_blocks = []
        # One sub-pixel upsampling block per 2x of the upsampling factor.
        for _ in range(int(self.upsample_factor / 2)):
            up_blocks.append(UpsampleBlock(64, 256))
        self.upsample_blocks = nn.Sequential(*up_blocks)

        self.conv3 = nn.Conv2d(64, 3, 9, stride=1, padding=4)

    def forward(self, img):
        first_conv = self.conv1(img)

        first_conv_out = first_conv.clone()
        res_net_out = self.res_blocks(first_conv_out)
        conv_3_in = self.bn2(self.conv2(res_net_out)) + first_conv_out

        up_sample_out = self.upsample_blocks(conv_3_in)
        conv_three_out = self.conv3(up_sample_out)
        return conv_three_out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 3, stride=1, padding=1)
        )
    def forward(self, img):
        x = self.model(img)
        # Global average pool over the spatial dimensions, then sigmoid to
        # obtain a single real/fake probability per image.
        return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)
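
As a usage illustration, a single simplified SRGAN training step with these modules might look like the sketch below. The Adam settings, the soft real/fake labels, and the VGG rescaling factor of 0.006 follow the training setup described in the next section; this is a sketch, not the exact training code.

generator, discriminator, vgg = Generator(4), Discriminator(), FeatureExtractor()
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999))
mse, bce = nn.MSELoss(), nn.BCELoss()

def srgan_step(lr_imgs, hr_imgs):
    """One alternating SRGAN update on a mini-batch of LR/HR pairs."""
    batch = lr_imgs.size(0)
    # Soft labels for the discriminator, as described in the next section.
    real_labels = 0.7 + 0.3 * torch.rand(batch, 1)
    fake_labels = 0.3 * torch.rand(batch, 1)

    # Generator update: rescaled VGG content loss + 1e-3 * adversarial loss
    # (BCE against a target of 1 equals -log D(G(I_LR))).
    sr_imgs = generator(lr_imgs)
    content_loss = 0.006 * mse(vgg(sr_imgs), vgg(hr_imgs))
    adversarial_loss = bce(discriminator(sr_imgs), torch.ones(batch, 1))
    g_loss = content_loss + 1e-3 * adversarial_loss
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

    # Discriminator update on real HR images and detached generated images.
    d_loss = bce(discriminator(hr_imgs), real_labels) + \
             bce(discriminator(sr_imgs.detach()), fake_labels)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()
    return g_loss.item(), d_loss.item()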

Empirical Results

We trained all networks on a Colab GPU using the DIV2K data set [2] of 800 images. We obtained the LR images by downsampling the HR images (BGR, C = 3) with a bicubic kernel and a downsampling factor of x4. For each mini-batch we crop 4 random 256 x 256 HR sub-images from distinct training images. Note that the generator can be applied to images of arbitrary size as it is fully convolutional. We scaled the LR input images and the HR images according to the normalization parameters of the VGG19 network, namely mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225], so the MSE loss was calculated on images with the same intensity range as the VGG feature maps. To obtain VGG losses on a scale comparable to the MSE loss, we rescale the VGG loss by a factor of 0.006.

For optimization we use Adam with $\beta_1 = 0.9$. The SRResNet network was trained with a learning rate of $10^{-4}$ for 3 epochs, and we employed this trained MSE-based SRResNet as initialization for the generator to avoid undesired local optima. We also set the target labels to a random number between 0.7 and 1 for real high-resolution images and between 0 and 0.3 for fake high-resolution images. All SRGAN variants were trained for 200 epochs at a learning rate of $10^{-4}$ and for another 40 epochs at a lower rate of $10^{-5}$, alternating updates to the generator and discriminator networks.
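
A minimal sketch of the data preparation described above (random 256 x 256 HR crops, VGG normalization, and bicubic x4 downsampling to obtain the LR input), assuming a DIV2K image path as a placeholder and reusing the degrade_bicubic sketch from earlier:

from PIL import Image
from torchvision import transforms

# VGG19 normalization parameters, applied to both the LR and HR images.
vgg_mean = [0.485, 0.456, 0.406]
vgg_std = [0.229, 0.224, 0.225]

hr_transform = transforms.Compose([
    transforms.RandomCrop(256),             # random 256 x 256 HR sub-image
    transforms.ToTensor(),
    transforms.Normalize(vgg_mean, vgg_std),
])

# Hypothetical usage on one DIV2K image (the path is a placeholder).
hr = hr_transform(Image.open("DIV2K/0001.png").convert("RGB"))  # 3 x 256 x 256
lr = degrade_bicubic(hr.unsqueeze(0), scale=4).squeeze(0)       # 3 x 64 x 64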

Some samples generated after training for 4300 and 4350 batches:

[Figure: generated SR samples after 4300 training batches]

[Figure: generated SR samples after 4350 training batches]

References

[1] Ledig, C., Theis, L., Huszar, F., Caballero, J., Cunningham, A., Acosta, A., . . . Shi, W. (2017, May 25). Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. Retrieved December 18, 2020, from https://arxiv.org/abs/1609.04802

[2] DIVerse 2K resolution high quality images as used for the challenges @ NTIRE (CVPR 2017 and CVPR 2018) and @ PIRM (ECCV 2018). (n.d.). Retrieved December 18, 2020, from https://data.vision.ee.ethz.ch/cvl/DIV2K/