kimiyoung/ssl_bad_gan

Hi, after I replace the code, please help


Hi,
It is a tensor dimension problem.
Just replace the line
norm_weight = self.weight * (weight_scale[None,:,None,None] / torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(0) + 1e-6)).expand_as(self.weight)
with
norm_weight = self.weight * (weight_scale[None,:,None,None] / torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(0) + 1e-6).reshape([-1, 1, 1, 1])).expand_as(self.weight)

Originally posted by @zcwang0702 in #6 (comment)

But after I replaced the code, I got a new error:
RuntimeError: The expanded size of the tensor (128) must match the existing size (3) at non-singleton dimension 0

Error with the original code:
File "/home/gis/PycharmProjects/guo/ssl_bad_gan-master/model.py", line 165, in forward
norm_weight = self.weight * (weight_scale[None,:,None,None] / torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(0) + 1e-6)).expand_as(self.weight)
RuntimeError: The expanded size of the tensor (5) must match the existing size (3) at non-singleton dimension 3
Thank you!
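Both error messages can be traced by hand with the broadcasting rules PyTorch shares with NumPy: shapes are aligned from the right, a size-1 dim stretches to match, and any other mismatch is an error; `expand_as` may likewise only stretch size-1 dims. The sketch below does this in plain Python. It assumes `self.weight` has shape (128, 3, 5, 5), laid out as (in_channels, out_channels, kH, kW) like PyTorch's `nn.ConvTranspose2d`, and that `weight_scale` has length 3; these sizes are inferred from the two error messages, not read from `model.py`. The helper names are hypothetical.

```python
# Plain-Python trace of the broadcasting in model.py line 165.
# Assumed (inferred from the error messages): self.weight is (128, 3, 5, 5)
# laid out as (in_channels, out_channels, kH, kW), and weight_scale has length 3.


def broadcast_shape(a, b):
    """Broadcast two shapes the way PyTorch/NumPy do: align from the right."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(max(x, y))
    return tuple(out)


def expand_mismatch_dim(src, target):
    """First dim where src.expand_as(target) fails, or None if it succeeds.

    expand_as may only stretch size-1 dims; every other dim must already match.
    Both shapes are assumed to have the same number of dims here.
    """
    for dim, (s, t) in enumerate(zip(src, target)):
        if s != 1 and s != t:
            return dim
    return None


weight = (128, 3, 5, 5)
scale = (1, 3, 1, 1)  # weight_scale[None, :, None, None]
norm = (3,)           # (self.weight ** 2).sum(3).sum(2).sum(0) keeps only dim 1

cases = {
    "original": broadcast_shape(scale, norm),                          # (1, 3, 1, 3)
    "reshape([-1, 1, 1, 1])": broadcast_shape(scale, (3, 1, 1, 1)),  # (3, 3, 1, 1)
    "reshape([1, -1, 1, 1])": broadcast_shape(scale, (1, 3, 1, 1)),  # (1, 3, 1, 1)
}

for name, shape in cases.items():
    dim = expand_mismatch_dim(shape, weight)
    if dim is None:
        print(f"{name}: {shape} expands to {weight}")
    else:
        print(f"{name}: {shape} fails at dim {dim} "
              f"(expanded size {weight[dim]}, existing size {shape[dim]})")
```

Under these assumptions the original line yields a (1, 3, 1, 3) tensor, which fails at dim 3 (5 vs 3), and the `[-1, 1, 1, 1]` reshape puts the three norms on dim 0, the input-channel axis in this layout, giving (3, 3, 1, 1), which fails at dim 0 (128 vs 3). Both match the messages above. Reshaping to `[1, -1, 1, 1]` keeps the norms on dim 1 alongside `weight_scale`.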

Try this instead:
norm_weight = self.weight * (weight_scale[None,:,None,None] / torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(0) + 1e-6).reshape([1, -1, 1, 1])).expand_as(self.weight)
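A minimal PyTorch sketch for checking this fix outside the model, assuming the weight is laid out as (in_channels, out_channels, kH, kW) as PyTorch does for `nn.ConvTranspose2d`, with the sizes 128/3/5 inferred from the error messages. `weight_scale` here is a random stand-in for the per-output-channel gain, not the model's actual parameter.

```python
import torch

torch.manual_seed(0)

# Assumed shapes (inferred from the error messages in this issue).
in_ch, out_ch, k = 128, 3, 5
weight = torch.randn(in_ch, out_ch, k, k)
weight_scale = torch.rand(out_ch) + 0.5  # hypothetical per-output-channel gain

# The suggested fix: reshape the per-channel norm to (1, out_ch, 1, 1) so it
# lines up with weight_scale[None, :, None, None] on dim 1.
norm = torch.sqrt((weight ** 2).sum(3).sum(2).sum(0) + 1e-6)  # shape (out_ch,)
norm_weight = weight * (weight_scale[None, :, None, None]
                        / norm.reshape([1, -1, 1, 1])).expand_as(weight)

# Equivalent form: keepdim=True keeps the reduced dims as size 1, so no reshape.
norm_kd = torch.sqrt(weight.pow(2).sum(dim=(0, 2, 3), keepdim=True) + 1e-6)
norm_weight_kd = weight * weight_scale.view(1, -1, 1, 1) / norm_kd

print(norm_weight.shape)  # torch.Size([128, 3, 5, 5])

# Each output channel of the normalised weight should have L2 norm ~= its gain.
channel_norms = norm_weight.pow(2).sum(dim=(0, 2, 3)).sqrt()
print(torch.allclose(channel_norms, weight_scale, atol=1e-4))  # True
```

Once the shapes line up, the trailing `.expand_as(self.weight)` is not strictly needed, because `*` broadcasts a (1, 3, 1, 1) tensor against the weight on its own.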