kimiyoung/ssl_bad_gan

expand_as error


I cloned the repo and modified it to run with Python 3.6. When running mnist_train.py, I get this error:
RuntimeError: The expanded size of the tensor (3) must match the existing size (160) at non-singleton dimension 3. Target sizes: [160, 2, 2, 3]. Tensor sizes: [160, 1, 1, 160]

on this line:
norm_weight = self.weight * (weight_scale[:,None,None,None] / torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(1) + 1e-6)).expand_as(self.weight)

I checked the shapes of every term in this statement:
self.weight.shape = [160, 2, 2, 3]
weight_scale[:,None,None,None] shape = [160, 1, 1, 1]
the torch.sqrt term shape = [160]
but the division came out as:
weight_scale[:,None,None,None] / torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(1) + 1e-6) shape = [160, 1, 1, 160]

This puzzles me. Why would a tensor of shape [160, 1, 1, 1] divided by a tensor of shape [160] produce a result of shape [160, 1, 1, 160]?
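
My best guess is that this is ordinary broadcasting: shapes are aligned from the last dimension backwards, a missing leading dimension is treated as size 1, and a size-1 dimension is stretched to match the other operand. Under that rule the [160] tensor lines up with dimension 3, not dimension 0. The sketch below only computes the resulting shape in plain Python; broadcast_shape is a hypothetical helper I wrote for illustration, not a PyTorch function.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the shape that broadcasting shapes a and b would produce.

    Hypothetical helper for illustration only; it mirrors the usual
    NumPy/PyTorch broadcasting rule but is not part of either library.
    """
    result = []
    # Walk both shapes from the trailing dimension; a missing dimension counts as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {a} with {b}")
        # A size-1 dimension is stretched to the other operand's size.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


# The [160] norm aligns with the LAST dimension of [160, 1, 1, 1].
print(broadcast_shape((160, 1, 1, 1), (160,)))          # (160, 1, 1, 160)
# If the norm kept its reduced dimensions, the shapes would match as intended.
print(broadcast_shape((160, 1, 1, 1), (160, 1, 1, 1)))  # (160, 1, 1, 1)
```

If that reading is right, the division itself is behaving as documented, and the real problem is that the torch.sqrt term has lost its three trailing size-1 dimensions.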

This shape, in turn, cannot be expanded to [160, 2, 2, 3] by expand_as, which seems reasonable.
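
A possible workaround, assuming the original code was written for an older PyTorch in which sum(dim) kept the reduced dimension as size 1 (I have not confirmed which version the repo targets): pass keepdim=True so the norm stays [160, 1, 1, 1]. The tensors below are random stand-ins for self.weight and weight_scale, not the values from the repo.

```python
import torch

# Stand-ins with the shapes reported above (assumed, not taken from the repo).
weight = torch.randn(160, 2, 2, 3)   # plays the role of self.weight
weight_scale = torch.ones(160)       # plays the role of weight_scale

# keepdim=True keeps each reduced dimension as size 1, so the norm is [160, 1, 1, 1].
norm = torch.sqrt(
    (weight ** 2).sum(3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True) + 1e-6
)

# Both operands are now [160, 1, 1, 1], so expand_as to [160, 2, 2, 3] succeeds.
norm_weight = weight * (weight_scale[:, None, None, None] / norm).expand_as(weight)

print(tuple(norm.shape))         # (160, 1, 1, 1)
print(tuple(norm_weight.shape))  # (160, 2, 2, 3)
```

Indexing the [160] result with [:, None, None, None] instead of using keepdim=True should give the same shape; I picked keepdim=True only because it leaves the rest of the line unchanged.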