What about upsampling?
chanshing opened this issue · 12 comments
Thank you for your awesome work. I have a question:
What if I want to perform upsampling instead of downsampling? As I understand it, aliasing is a problem only when downsampling. But I came across the paper "A Style-Based Generator Architecture for Generative Adversarial Networks", which also blurs during upsampling and cites your work. There, the blur is applied after upsampling (instead of before, as in downsampling). Could you comment on that?
I intend to apply your technique in a VAE or GAN, and I would like to know whether I should include the blur in the decoder/generator.
Thanks for the question! Yes, when upsampling, blurring should be applied after. When downsampling, it should be applied before. In both cases, it is applied at the higher resolution.
To fully understand why, I recommend taking a signal processing course and seeing what these operations are actually doing in the Fourier domain.
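To make the Fourier-domain picture concrete for the upsampling case, here is a minimal NumPy sketch. The 1D cosine, the 2x factor, and the Tri-3 [1, 2, 1] filter are illustrative choices. Inserting zeros leaves the original tone in place but also creates a mirror-image copy of it near the new Nyquist frequency. Blurring at the higher resolution is what suppresses that copy.

```python
import numpy as np

N = 64
n = np.arange(N)
x = np.cos(2 * np.pi * 3 * n / N)  # low-frequency tone: 3 cycles over N samples

# 2x upsampling, step 1: insert a zero after every sample (length 2N).
up = np.zeros(2 * N)
up[::2] = x

# Step 2: blur at the higher resolution with Tri-3, i.e. [1, 2, 1] / 2.
# np.roll makes this a circular convolution, so it multiplies the FFT exactly.
blurred = 0.5 * np.roll(up, 1) + up + 0.5 * np.roll(up, -1)


def magnitude(signal: np.ndarray, k: int) -> float:
    """Magnitude of DFT bin k."""
    return float(np.abs(np.fft.rfft(signal))[k])


# Zero insertion keeps the tone at bin 3 but adds a mirror image at bin N - 3.
signal_bin, image_bin = 3, N - 3
print(magnitude(up, image_bin) / magnitude(up, signal_bin))            # ~1.0
print(magnitude(blurred, image_bin) / magnitude(blurred, signal_bin))  # well below 0.01
```

The blur also leaves the original samples untouched (`blurred[::2]` equals `x`), because the neighbors of every original sample are inserted zeros.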
I see... The more I think about it the more it makes sense. I will follow your recommendation. Thanks for the quick reply.
May I ask one more question? I read your previous reply related to this:
"Thanks for the question! Yes, upsample with ConvTranspose2d with groups=in_channels, and then use [1, 1] or [1, 2, 1] as your weights. For TensorFlow, perhaps you can refer to the StyleGAN2 code/paper, which does smoothed upsampling as well."
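A minimal PyTorch sketch of that suggestion, assuming a 2x factor: a ConvTranspose2d with groups=in_channels (so each channel is filtered on its own) and frozen [1, 1] (Rect-2) or [1, 2, 1] (Tri-3) weights. The helper name fixed_blur_upsampler and the gain-4 normalization, which keeps constant inputs constant away from the borders, are assumptions for illustration.

```python
import torch
import torch.nn as nn


def fixed_blur_upsampler(in_channels: int, taps) -> nn.ConvTranspose2d:
    """Hypothetical helper: 2x upsampler with groups=in_channels and frozen blur weights."""
    k = len(taps)               # 2 for Rect-2 [1, 1], 3 for Tri-3 [1, 2, 1]
    pad = (k - 1) // 2          # 0 for Rect-2, 1 for Tri-3
    out_pad = 2 + 2 * pad - k   # makes the output exactly 2x the input size
    layer = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=k, stride=2,
                               padding=pad, output_padding=out_pad,
                               groups=in_channels, bias=False)
    t = torch.tensor(taps, dtype=torch.float32)
    w = torch.outer(t, t)
    w = 4.0 * w / w.sum()  # gain 4 (= 2 x 2) compensates for the inserted zeros
    with torch.no_grad():
        layer.weight.copy_(w.expand(in_channels, 1, k, k))
    layer.weight.requires_grad_(False)  # fixed, not learned
    return layer


x = torch.arange(8, dtype=torch.float32).view(1, 2, 2, 2)  # 2 channels, 2x2 each
rect = fixed_blur_upsampler(2, [1.0, 1.0])(x)       # pixel duplication
tri = fixed_blur_upsampler(2, [1.0, 2.0, 1.0])(x)   # interpolation between pixels
print(rect[0, 0])
print(tri[0, 0])  # last row/column is halved: nothing lies beyond the border
```

Rect-2 simply repeats each pixel into a 2x2 block, while Tri-3 keeps the original pixels at even positions and fills the odd positions with averages of their neighbors.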
I want to check whether I correctly understand the PyTorch implementation. For downsampling we do:
Conv(stride=1), ReLU, Conv(stride=k, fixed blur weight)
For upsampling we do:
Conv(stride=1), ReLU, ConvTransposed(stride k, fixed blur weight)
Is this correct?
Yes, that's correct. For upsampling, it's not relevant what happens before the convtranspose layer. Just make sure you blur after upsampling. So I would phrase it as:
Whatever, ConvTransposed(stride k, fixed blur weight)
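The two pipelines above, sketched as PyTorch modules for a 2x factor. FixedBlur is a hypothetical helper: it stores the [1, 2, 1] x [1, 2, 1] kernel as a buffer so the optimizer never updates it, and it assumes zero padding (other padding modes, such as reflect, are also reasonable). In the upsampling block, the Conv and ReLU stand in for "Whatever".

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FixedBlur(nn.Module):
    """Depthwise [1, 2, 1] x [1, 2, 1] blur with stride 2.

    mode="down": Conv(stride=2, fixed blur weight), i.e. blur then subsample
    mode="up":   ConvTransposed(stride=2, fixed blur weight), i.e. zero-insert then blur
    """

    def __init__(self, channels: int, mode: str):
        super().__init__()
        t = torch.tensor([1.0, 2.0, 1.0])
        k = torch.outer(t, t) / 16.0
        if mode == "up":
            k = 4.0 * k  # compensate for the zeros inserted by the transposed conv
        # A buffer, not a Parameter, so it is saved with the model but never trained.
        self.register_buffer("weight", k.expand(channels, 1, 3, 3).contiguous())
        self.mode = mode
        self.channels = channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mode == "down":
            return F.conv2d(x, self.weight, stride=2, padding=1, groups=self.channels)
        return F.conv_transpose2d(x, self.weight, stride=2, padding=1,
                                  output_padding=1, groups=self.channels)


C = 16
# Downsampling: Conv(stride=1), ReLU, Conv(stride=2, fixed blur weight)
down_block = nn.Sequential(nn.Conv2d(C, C, 3, padding=1), nn.ReLU(), FixedBlur(C, "down"))
# Upsampling: Whatever, ConvTransposed(stride=2, fixed blur weight)
up_block = nn.Sequential(nn.Conv2d(C, C, 3, padding=1), nn.ReLU(), FixedBlur(C, "up"))

x = torch.randn(1, C, 32, 32)
print(down_block(x).shape)  # torch.Size([1, 16, 16, 16])
print(up_block(x).shape)    # torch.Size([1, 16, 64, 64])
```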
@richzhang Thank you for clarifying. If you don't mind, I have two more questions:
- In practice, would you first upsample and then apply conv+ReLU, or first conv+ReLU and then upsample? In other words, ConvTransposed(stride=k, fixed blur weight), Conv(stride=1), ReLU or Conv(stride=1), ReLU, ConvTransposed(stride=k, fixed blur weight)? Many papers seem to go for the first option, perhaps influenced by https://distill.pub/2016/deconv-checkerboard/
- You mention in the paper: "Note that applying the Rect-2 and Tri-3 filters while upsampling correspond to “nearest” and “bilinear” upsampling, respectively." I think this is true only for strides of 2 (an upsampling factor of 2). That is the common case anyway, so you probably omitted it, but I want to confirm, since others might use larger factors/strides.
Do Whatever, ConvTransposed(stride k, fixed blur weight), Whatever. As long as there is blurring directly after upsampling, you are antialiasing.
Yes, you're right about the stride.
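To check the stride point numerically, a minimal 1D PyTorch sketch (rect_taps, triangle_taps, and transposed_upsample are hypothetical helper names). For an upsampling factor k, k ones reproduce nearest-neighbor upsampling, and a triangle of length 2k - 1 reproduces linear interpolation sampled with align_corners=True (output length (H - 1)k + 1). Tri-3 is that triangle only for k = 2, so it matches at k = 2 but not at k = 4.

```python
import torch
import torch.nn.functional as F


def rect_taps(k: int) -> torch.Tensor:
    """Rect-k: k ones (Rect-2 = [1, 1])."""
    return torch.ones(k)


def triangle_taps(k: int) -> torch.Tensor:
    """Triangle of length 2k - 1 with peak 1 (k=2 gives Tri-3, [1, 2, 1] / 2)."""
    m = torch.arange(2 * k - 1, dtype=torch.float32)
    return (k - (m - (k - 1)).abs()) / k


def transposed_upsample(x: torch.Tensor, taps: torch.Tensor, k: int, padding: int) -> torch.Tensor:
    """Single-channel 1D transposed conv with stride k and fixed taps."""
    return F.conv_transpose1d(x, taps.view(1, 1, -1), stride=k, padding=padding)


x = torch.randn(1, 1, 10)
H = x.shape[-1]
for k in (2, 4):
    nearest = F.interpolate(x, scale_factor=k, mode="nearest")
    linear = F.interpolate(x, size=(H - 1) * k + 1, mode="linear", align_corners=True)
    rect = transposed_upsample(x, rect_taps(k), k, padding=0)
    tri = transposed_upsample(x, triangle_taps(k), k, padding=k - 1)
    tri3 = transposed_upsample(x, triangle_taps(2), k, padding=1)  # Tri-3 kept at stride k
    print(k,
          torch.allclose(rect, nearest),
          torch.allclose(tri, linear, atol=1e-6),
          torch.allclose(tri3, linear, atol=1e-6))
# expected: "2 True True True", then "4 True True False"
```

At stride 4, Tri-3 only covers three of every four output positions, so every fourth output stays zero instead of being interpolated.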
@richzhang - I was trying to see how I could implement upsampling using ConvTranspose2d and fixed weights, and it appears that PyTorch's nn.functional.upsample method uses the [1., 3., 3., 1.] kernel for 'bilinear' upsampling, rather than the Tri [1., 2., 1.] kernel you mention in the paper (and above).
Here is a simple reproducible code block that seems to confirm this for me:
import torch
import torch.nn.functional as F
import numpy as np
# Set up a 6x6 image of zeros with a 2x2 block of ones in the center
a = np.zeros((6, 6))
a[2:4, 2:4] = 1
a_t = torch.from_numpy(a[None, None, :, :])
# Set up the separable [1, 3, 3, 1] filter
filt_1d = np.array([1., 3., 3., 1.])
filt_2d = filt_1d[None, :] * filt_1d[:, None]
# Conv-transpose upsampling; the factor 4 (= 2x2) compensates for the inserted zeros
filt = 4 * torch.from_numpy(filt_2d[None, None, :, :] / filt_2d.sum())
conv = F.conv_transpose2d(a_t, filt, stride=2, padding=1)
conv_np = conv.numpy().squeeze()
# PyTorch built-in upsampling (F.upsample is deprecated in favor of F.interpolate;
# align_corners=False is already the default for 'bilinear' and is written out here)
upsample = F.interpolate(a_t, scale_factor=2, mode='bilinear', align_corners=False)
upsample_np = upsample.numpy().squeeze()
# Check that they match (assert_allclose raises on a mismatch and returns None otherwise)
print(np.testing.assert_allclose(upsample_np, conv_np) is None)
# Output: True
Am I doing something wrong here?
Yes, you started with a 2x2 block, a[2:4, 2:4] = 1. You effectively convolved a [1 1] with a [1 2 1] to get [1 3 3 1]. Use a delta function, a[2,2]=1, instead.
Thanks for the quick reply! I am treating a_t as a rough mockup of the input image to be upsampled. Even if I change the above code to set a[2, 2] = 1 instead, I only get a match with the [1 3 3 1] filter, not with [1 2 1].
Thanks, got it. I'm not sure the PyTorch implementation is correct. If the signal is properly upsampled, simple subsampling, upsample[:,:,::2,::2], should give you your original input back.
Here's the operation in SciPy, which is consistent with the [1,2,1] filter.
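import numpy as np
from scipy.interpolate import RegularGridInterpolator
a = np.zeros((5,5))
a[2,2] = 1
sample = np.arange(0,5,.5)
# interp2d was deprecated and later removed from SciPy; RegularGridInterpolator performs the
# same bilinear interpolation on a regular grid (fill_value=None extrapolates past index 4)
interp = RegularGridInterpolator((np.arange(5), np.arange(5)), a, method='linear', bounds_error=False, fill_value=None)
rows, cols = np.meshgrid(sample, sample, indexing='ij')
print(interp(np.stack([rows, cols], axis=-1)))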
Output
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[0. 0. 0. 0.25 0.5 0.25 0. 0. 0. 0. ]
[0. 0. 0. 0.5 1. 0.5 0. 0. 0. 0. ]
[0. 0. 0. 0.25 0.5 0.25 0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]]
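The subsampling test suggested above can also be run directly in PyTorch. A minimal 1D sketch, assuming 1D mode='linear' stands in for one axis of 'bilinear' (which is separable): Tri-3 upsampling through a transposed convolution passes the round trip, while F.interpolate with its default align_corners=False does not.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8)

# Tri-3 upsampling: zero insertion plus a [1, 2, 1] / 2 blur, via a stride-2 transposed conv.
tri3 = torch.tensor([[[0.5, 1.0, 0.5]]])
up_tri = F.conv_transpose1d(x, tri3, stride=2, padding=1, output_padding=1)

# PyTorch's built-in linear upsampling with its default align_corners=False.
up_torch = F.interpolate(x, scale_factor=2, mode="linear", align_corners=False)

# If nothing was shifted, keeping every other sample returns the original signal.
print(torch.allclose(up_tri[..., ::2], x))    # True
print(torch.allclose(up_torch[..., ::2], x))  # False
```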
Okay, I understand now. [0 1 0 0] should be upsampled to [0 0.5 1.0 0.5 0 0 0 0]. Torch then shifts over by a half index and reinterpolates to [0 .25 .75 .75 .25 0 0 0].
I wouldn't have done that half-index shift.
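A minimal sketch of that half-index shift on the same [0 1 0 0] example, with 1D mode='linear' standing in for one axis of 'bilinear'. The coordinate comments follow PyTorch's align_corners=False convention, where input and output are aligned by the corners of their corner pixels.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[0.0, 1.0, 0.0, 0.0]]])  # the [0 1 0 0] example

# Tri-3 upsampling samples the linear interpolant at input coordinates 0, 0.5, 1, 1.5, ...
up_tri = F.conv_transpose1d(x, torch.tensor([[[0.5, 1.0, 0.5]]]),
                            stride=2, padding=1, output_padding=1)
# PyTorch's default (align_corners=False) samples at -0.25, 0.25, 0.75, 1.25, ...
up_torch = F.interpolate(x, scale_factor=2, mode="linear", align_corners=False)

print(up_tri.flatten().tolist())    # [0.0, 0.5, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0]
print(up_torch.flatten().tolist())  # [0.0, 0.25, 0.75, 0.75, 0.25, 0.0, 0.0, 0.0]

# Away from the borders, each PyTorch sample is the midpoint of two neighboring
# Tri-3 samples: a half-index shift followed by linear re-interpolation.
midpoints = 0.5 * (up_tri[..., :-1] + up_tri[..., 1:])
print(torch.allclose(up_torch[..., 1:-1], midpoints[..., :-1]))  # True
```

That extra re-interpolation is itself a [1, 1] / 2 blur, which is why the effective kernel becomes [1, 2, 1] convolved with [1, 1], i.e. [1, 3, 3, 1].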
Thanks @richardyang. It does appear that PyTorch is doing something unusual with its upsampling. Btw, I really enjoyed reading through this paper and the experiments you laid out in it 👍