RuntimeError in GAN training
yuta-tech opened this issue · 2 comments
When training the GAN with text, an error occurred in the function "update_generator_running_avg".
My environment is:
OS : Ubuntu 18.04.4 LTS
Python : 3.6.12
Pytorch : 1.6.0
CUDA : 10.1
The error details are below:
Traceback (most recent call last):
  File "run_generation.py", line 713, in
    update_generator_running_avg(epoch)
  File "run_generation.py", line 452, in update_generator_running_avg
    param.mul_(alpha).add_(g_state_dict[k], alpha=1-alpha)
RuntimeError: result type Float can't be cast to the desired output type Long
Using print debugging inside the for loop, I found that the dtype of "param" changes from float32 to long (int64) partway through the loop:
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.float32
type of alpha : <class 'float'>
type of param : torch.int64
type of alpha : <class 'float'>
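The last lines of that output suggest the cause: one entry in the state dict is an int64 tensor. Multiplying an int64 tensor by a Python float promotes the result to float32, and an in-place op such as "mul_" cannot write a float32 result back into an int64 tensor. The sketch below should raise the same RuntimeError as the traceback above. It assumes the int64 entry is a counter-like buffer; the issue does not say which key it is, though BatchNorm's "num_batches_tracked" is one common int64 entry in a state dict.

```python
import torch

# Stand-in for an int64 state-dict entry (assumption: e.g. a BatchNorm
# "num_batches_tracked" buffer; the issue does not name the key).
counter = torch.tensor(10, dtype=torch.int64)
alpha = 0.999

message = None
try:
    # int64 * float promotes to float32; the in-place op would have to
    # downcast that result to int64, which PyTorch refuses to do.
    counter.mul_(alpha)
except RuntimeError as err:
    message = str(err)

print(message)
```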
How can I solve this?
Fixed in 27cdbf2!
I confirmed it, and the error is solved!
Thank you for your prompt response.
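The thread does not show what commit 27cdbf2 actually changed. As a hedged sketch of one way to avoid this error, a running-average (EMA) update can average only floating-point entries and copy integer entries unchanged. "update_running_avg", "avg_model", and "model" are hypothetical names for illustration, not the repository's API.

```python
import copy

import torch
import torch.nn as nn


def update_running_avg(avg_model: nn.Module, model: nn.Module, alpha: float = 0.999) -> None:
    """Hypothetical EMA helper: avg = alpha * avg + (1 - alpha) * current."""
    src = model.state_dict()
    with torch.no_grad():
        # state_dict() tensors share storage with the module, so in-place
        # updates here modify avg_model directly.
        for k, param in avg_model.state_dict().items():
            if param.is_floating_point():
                # Weights, biases, running statistics: exponential moving average.
                param.mul_(alpha).add_(src[k], alpha=1 - alpha)
            else:
                # Integer buffers (e.g. num_batches_tracked) cannot hold a
                # float result in place, so copy them instead of averaging.
                param.copy_(src[k])


# Usage: a model with a BatchNorm layer, whose state dict contains an int64 buffer.
torch.manual_seed(0)
g = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
g_avg = copy.deepcopy(g)
g(torch.randn(8, 4))  # a training-mode forward pass increments num_batches_tracked
update_running_avg(g_avg, g, alpha=0.9)
print(g_avg[1].num_batches_tracked)
```

Copying integer buffers is one design choice; skipping them entirely would also avoid the error, at the cost of the averaged model's counters falling out of sync with the trained model.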