The batch dim in dct() function
kitaev-chen commented
I'm confused about the implementation of the dct function.
The first problem is the batch dim N = x_shape[-1]. Whether the data is 2D or 3D, with dim format (N, L, D) or (N, C, H, W), isn't the batch size N = x_shape[0]?
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]  # <==============
    x = x.contiguous().view(-1, N)
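As the docstring says, the transform runs over the last dimension, so N is the length of that dimension rather than the batch size. The minimal NumPy sketch below mirrors the reshape trick under that reading. It is not torch_dct's code: `dct_matrix` and `dct_last_dim` are hypothetical helpers, and an explicit orthonormal DCT-II matrix (the `norm='ortho'` case) is used purely for clarity. Every leading axis (batch, channels, sequence position) is folded into rows and carried along untouched.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II matrix: row k holds the k-th cosine basis vector."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    mat = np.cos(np.pi * (2 * i + 1) * k / (2 * n)) * np.sqrt(2.0 / n)
    mat[0, :] = np.sqrt(1.0 / n)  # the k = 0 row gets the smaller scale factor
    return mat

def dct_last_dim(x):
    """DCT-II over the LAST axis only, following the reshape trick in dct()."""
    x_shape = x.shape
    n = x_shape[-1]                 # transform length, not the batch size
    rows = x.reshape(-1, n)         # all leading axes (batch, L, ...) become rows
    coefs = rows @ dct_matrix(n).T  # each row is transformed independently
    return coefs.reshape(x_shape)   # restore the original leading axes

x = np.random.default_rng(0).standard_normal((2, 3, 8))  # (N, L, D)
y = dct_last_dim(x)
```

Each slice `x[b, l]` gets its own length-8 transform, so the batch axis never enters the computation. That is why the code only needs the size of the last axis.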
Another question is about DCTBlur for image data. Image data is 3D, with dims (C, H, W). Why do you use a 2D DCT instead of a 3D one?
class DCTBlur(nn.Module):

    def __init__(self, blur_sigmas, image_size, device):
        super(DCTBlur, self).__init__()
        self.blur_sigmas = torch.tensor(blur_sigmas).to(device)
        freqs = np.pi*torch.linspace(0, image_size-1,
        self.frequencies_squared = freqs[:, None]**2 + freqs[None, :]**2

    def forward(self, x, fwd_steps):
        if len(x.shape) == 4:
            sigmas = self.blur_sigmas[fwd_steps][:, None, None, None]
        elif len(x.shape) == 3:
            sigmas = self.blur_sigmas[fwd_steps][:, None, None]
        t = sigmas**2/2
        dct_coefs = torch_dct.dct_2d(x, norm='ortho')  # <==============
        dct_coefs = dct_coefs * torch.exp(- self.frequencies_squared * t)
        return torch_dct.idct_2d(dct_coefs, norm='ortho')
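torch_dct.dct_2d transforms only the last two axes, so for (N, C, H, W) input the batch and channel axes are treated like the leading axes in dct(). The NumPy sketch below imitates DCTBlur.forward under some assumptions. The helper names are hypothetical, images are assumed square (H == W == image_size), and because the line defining freqs is cut off in the post, the frequency grid pi * k / image_size used here is an assumption rather than the repository's code. The per-sample sigmas are reshaped to broadcast over the spatial axes, which is what the [:, None, None(, None)] indexing does.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II matrix: row k holds the k-th cosine basis vector."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    mat = np.cos(np.pi * (2 * i + 1) * k / (2 * n)) * np.sqrt(2.0 / n)
    mat[0, :] = np.sqrt(1.0 / n)
    return mat

def dct_2d(x):
    """DCT-II over the last two axes only; leading axes (N, C, ...) are left alone."""
    return dct_matrix(x.shape[-2]) @ x @ dct_matrix(x.shape[-1]).T

def idct_2d(coefs):
    """Inverse of dct_2d (orthonormal matrices are inverted by transposing)."""
    return dct_matrix(coefs.shape[-2]).T @ coefs @ dct_matrix(coefs.shape[-1])

def dct_blur(x, sigmas):
    """Heat-equation blur of a (N, H, W) or (N, C, H, W) batch, one sigma per sample."""
    image_size = x.shape[-1]
    # ASSUMPTION: frequency grid pi * k / image_size for k = 0..image_size-1
    freqs = np.pi * np.arange(image_size) / image_size
    freqs_sq = freqs[:, None] ** 2 + freqs[None, :] ** 2  # shape (H, W)
    # (N,) -> (N, 1, 1) or (N, 1, 1, 1), like the [:, None, None(, None)] indexing
    sigmas = np.asarray(sigmas, dtype=float).reshape((-1,) + (1,) * (x.ndim - 1))
    t = sigmas ** 2 / 2
    return idct_2d(dct_2d(x) * np.exp(-freqs_sq * t))

rng = np.random.default_rng(0)
rgb = rng.standard_normal((2, 3, 16, 16))  # (N, C, H, W)
gray = rng.standard_normal((2, 16, 16))    # (N, H, W), the 3-dim branch
blurred_rgb = dct_blur(rgb, [0.5, 3.0])
blurred_gray = dct_blur(gray, [0.5, 3.0])
```

The (0, 0) frequency has a damping factor of exp(0) = 1, so each channel's mean is preserved while the higher frequencies decay faster for larger sigma. Because the channel axis is never transformed, blurring an RGB batch gives the same result as blurring each channel on its own.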
KokeCacao commented
My understanding:
- N does not represent the batch size here. It is the length of the dimension the DCT is applied over (the last one).
- MNIST is grayscale, so each image is 2D.
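Even for colour images, a 2D transform over (H, W) is arguably the right choice: the heat equation blurs in space, and a 3D DCT over (C, H, W) would also diffuse intensity between colour channels. The sketch below illustrates this with a hypothetical `heat_blur` helper that applies an orthonormal DCT-II along any chosen axes. It reuses the assumed pi * k / n frequency grid, and nothing in it comes from the repository.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II matrix: row k holds the k-th cosine basis vector."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    mat = np.cos(np.pi * (2 * i + 1) * k / (2 * n)) * np.sqrt(2.0 / n)
    mat[0, :] = np.sqrt(1.0 / n)
    return mat

def transform(x, axes, inverse=False):
    """Apply the (inverse) orthonormal DCT-II along each non-negative axis in `axes`."""
    for ax in axes:
        c = dct_matrix(x.shape[ax])
        c = c.T if inverse else c
        # tensordot puts the transformed axis first; moveaxis puts it back in place
        x = np.moveaxis(np.tensordot(c, x, axes=([1], [ax])), 0, ax)
    return x

def heat_blur(x, sigma, axes):
    """Hypothetical heat blur over `axes`: damp each frequency by exp(-|omega|^2 t)."""
    axes = [ax % x.ndim for ax in axes]  # normalise negative axes
    freqs_sq = np.zeros(x.shape)
    for ax in axes:
        n = x.shape[ax]
        shape = [1] * x.ndim
        shape[ax] = n
        # ASSUMPTION: frequency grid pi * k / n along each blurred axis
        freqs_sq = freqs_sq + ((np.pi * np.arange(n) / n) ** 2).reshape(shape)
    coefs = transform(x, axes) * np.exp(-freqs_sq * sigma ** 2 / 2)
    return transform(coefs, axes, inverse=True)

red = np.zeros((3, 8, 8))  # (C, H, W): a flat, pure-red image
red[0] = 1.0
blur_2d = heat_blur(red, 2.0, axes=(-2, -1))      # per-channel spatial blur, like dct_2d
blur_3d = heat_blur(red, 2.0, axes=(-3, -2, -1))  # blur that also runs across channels
```

Here `blur_2d` equals `red`, because a spatially flat image has only a zero-frequency component in H and W, and the blur never damps it. In `blur_3d`, by contrast, the red intensity leaks into the green and blue channels, although the total intensity is preserved.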