github-pengge/PyTorch-progressive_growing_of_gans

torch 0.3, Python 3.6: RuntimeError

oneHuster opened this issue · 2 comments

Traceback (most recent call last):
  File "train.py", line 365, in <module>
    pggan.train()
  File "train.py", line 286, in train
    self.train_phase(R, phase, batch_size, _range[0]*batch_size, _range[0], _range[1])
  File "train.py", line 240, in train_phase
    self.forward_D(cur_level, detach=True)
  File "train.py", line 196, in forward_D
    self.d_real = self.D(self.real, cur_level=cur_level, gdrop_strength=strength)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "models/model.py", line 218, in forward
    return self.output_layer(x, y, cur_level, insert_y_at)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "models/base_model.py", line 280, in forward
    x = self.chain[max_level](x)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/container.py", line 67, in forward
    input = module(input)
  File "/home/jurh/anaconda2/envs/th03/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "models/base_model.py", line 74, in forward
    vals = torch.mean(vals, keepdim=True)
RuntimeError: mean() missing 1 required positional arguments: "dim"

@oneHuster
You can fix that by changing line 74 in models/base_model.py so that it passes the reduction dimension explicitly, as the error message requires:

 vals = torch.mean(vals, dim=1, keepdim=True)
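For context, here is a minimal sketch of what the corrected call does to tensor shapes. It assumes, though the traceback does not confirm this, that line 74 sits in a minibatch standard-deviation layer of the discriminator that receives an (N, C, H, W) tensor. It uses torch.std in place of whatever std helper the repository actually uses, and the function name minibatch_stddev_feature is hypothetical. The sketch targets the current PyTorch API.

```python
import torch


def minibatch_stddev_feature(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: append one minibatch-stddev feature map to x.

    Assumes x has shape (N, C, H, W); torch.std stands in for the
    repository's own std computation, which is not shown in the issue.
    """
    # Std of every feature across the batch axis: (N, C, H, W) -> (1, C, H, W).
    vals = torch.std(x, dim=0, keepdim=True)
    # The fixed call: average over the channel axis only, keeping it: (1, 1, H, W).
    vals = torch.mean(vals, dim=1, keepdim=True)
    # Repeat the statistic for every sample: (1, 1, H, W) -> (N, 1, H, W).
    vals = vals.expand(x.size(0), 1, x.size(2), x.size(3))
    # Append it as one extra channel: (N, C + 1, H, W).
    return torch.cat([x, vals], dim=1)


x = torch.randn(4, 8, 4, 4)
out = minibatch_stddev_feature(x)
print(out.shape)  # torch.Size([4, 9, 4, 4])
```

With dim=1 and keepdim=True the result keeps a singleton channel axis, so it can be expanded and concatenated without any extra reshaping.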

thx