Potential errors in code
cats-food opened this issue · 1 comment
cats-food commented
@crashmoon Thanks for your great work! I have a question about your code. In your file gan_lstm.py, lines 338-356 read as follows:
class Net_G(nn.Module):
    def __init__(self):
        super(Net_G, self).__init__()
        self.unet_1 = Unet(3, Param.unet_channel * 8)
        self.unet_2 = Unet(6, Param.unet_channel * 8 * 3)
        self.unet_3 = Unet(6, Param.unet_channel * 8 * 3)
        self.unet_4 = Unet(6, Param.unet_channel * 8 * 3)
        self.rnn = nn.LSTMCell(Param.unet_channel * 8, Param.unet_channel * 8 * 2)

    def forward(self, data_1, data_2, data_3, data_4, h0, c0):
        #print(data_1.size())
        unet_out_1, unet_mid_1 = self.unet_1(data_1)
        h1, c1 = self.rnn(unet_mid_1.view(Param.batch_size, -1), (h0, c0))
        unet_out_2, unet_mid_2 = self.unet_2(torch.cat((data_1, unet_out_1), 1), h1) # check
        h2, c2 = self.rnn(unet_mid_2.view(Param.batch_size, -1), (h1, c1))
        unet_out_3, unet_mid_3 = self.unet_3(torch.cat((data_1, unet_out_2), 1), h2) # check
        h3, c3 = self.rnn(unet_mid_3.view(Param.batch_size, -1), (h2, c2))
        unet_out_4, unet_mid_4 = self.unet_4(torch.cat((data_1, unet_out_3), 1), h3) # check
        return unet_out_1, unet_out_2, unet_out_3, unet_out_4
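In my opinion, the first argument passed into self.unet_2, self.unet_3 and self.unet_4 (the calls I've marked with '# check') should concatenate data_2, data_3 and data_4 respectively, not data_1. As written, data_2, data_3 and data_4 are accepted by forward() but never used, so every step sees only the first frame. Could you please check this, and correct me if I am wrong? Thanks!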
youarenotaloneor commented
My gan_lstm.py runs, but it doesn't produce any results. Why? It has been running for an hour with no error reported and no output.