crashmoon/Progressive-Generative-Networks

Potential errors in code

cats-food opened this issue · 1 comment

@crashmoon Thanks for your great work! I have some questions about your code. In your file gan_lstm.py, lines 338-356 read as follows:

class Net_G(nn.Module):
    def __init__(self):
        super(Net_G, self).__init__()
        self.unet_1 = Unet(3, Param.unet_channel * 8)
        self.unet_2 = Unet(6, Param.unet_channel * 8 * 3)
        self.unet_3 = Unet(6, Param.unet_channel * 8 * 3)
        self.unet_4 = Unet(6, Param.unet_channel * 8 * 3)
        self.rnn = nn.LSTMCell(Param.unet_channel * 8, Param.unet_channel * 8 * 2)

    def forward(self, data_1, data_2, data_3, data_4, h0, c0):
        #print(data_1.size())
        unet_out_1, unet_mid_1 = self.unet_1(data_1)
        h1, c1 = self.rnn(unet_mid_1.view(Param.batch_size, -1), (h0, c0))
        unet_out_2, unet_mid_2 = self.unet_2(torch.cat((data_1, unet_out_1), 1), h1)        # check
        h2, c2 = self.rnn(unet_mid_2.view(Param.batch_size, -1), (h1, c1))
        unet_out_3, unet_mid_3 = self.unet_3(torch.cat((data_1, unet_out_2), 1), h2)        # check
        h3, c3 = self.rnn(unet_mid_3.view(Param.batch_size, -1), (h2, c2))
        unet_out_4, unet_mid_4 = self.unet_4(torch.cat((data_1, unet_out_3), 1), h3)        # check
        return unet_out_1, unet_out_2, unet_out_3, unet_out_4
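
One quick way to confirm that the quoted forward() never reads data_2, data_3 or data_4 is to parse it with Python's standard ast module and list the parameters that are never loaded. This is a minimal sketch: FORWARD_SRC holds the method copied from above (comments dropped), and unused_params is a hypothetical helper, not part of the repository. Because the source is only parsed, not executed, torch and the repository's Param and Unet are not needed.

```python
import ast
import textwrap

# forward() as quoted above, kept as text so it can be parsed without torch.
FORWARD_SRC = textwrap.dedent('''
    def forward(self, data_1, data_2, data_3, data_4, h0, c0):
        unet_out_1, unet_mid_1 = self.unet_1(data_1)
        h1, c1 = self.rnn(unet_mid_1.view(Param.batch_size, -1), (h0, c0))
        unet_out_2, unet_mid_2 = self.unet_2(torch.cat((data_1, unet_out_1), 1), h1)
        h2, c2 = self.rnn(unet_mid_2.view(Param.batch_size, -1), (h1, c1))
        unet_out_3, unet_mid_3 = self.unet_3(torch.cat((data_1, unet_out_2), 1), h2)
        h3, c3 = self.rnn(unet_mid_3.view(Param.batch_size, -1), (h2, c2))
        unet_out_4, unet_mid_4 = self.unet_4(torch.cat((data_1, unet_out_3), 1), h3)
        return unet_out_1, unet_out_2, unet_out_3, unet_out_4
''')


def unused_params(src):
    """Return the parameters of the first function in `src` that are never read."""
    func = next(n for n in ast.walk(ast.parse(src)) if isinstance(n, ast.FunctionDef))
    params = [a.arg for a in func.args.args]
    # Every name that is loaded (read) anywhere inside the function.
    loaded = {
        n.id
        for n in ast.walk(func)
        if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
    }
    return [p for p in params if p not in loaded]


print(unused_params(FORWARD_SRC))  # ['data_2', 'data_3', 'data_4']
```

This only shows that the three inputs are ignored; it cannot tell whether that is intentional.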

In my opinion, the data_1 that is concatenated (via torch.cat) into the inputs of self.unet_2, self.unet_3 and self.unet_4 (the lines marked '# check') should be data_2, data_3 and data_4 respectively. As written, forward() accepts data_2, data_3 and data_4 but never uses them. Could you please check this, and correct me if I am wrong? Thanks!
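
Assuming data_2, data_3 and data_4 really are the intended per-stage inputs (the hypothesis above, not something the repository confirms), the proposed wiring can be traced without torch. In this sketch Recorder, cat and forward_proposed are hypothetical stand-ins: strings replace tensors, Recorder logs what each U-Net or LSTM-cell call receives, and the hidden state is threaded through the stages in the same order as the quoted code (the .view() reshape is omitted).

```python
class Recorder:
    """Hypothetical stand-in for a Unet or nn.LSTMCell that logs its inputs."""

    def __init__(self, name, log):
        self.name, self.log, self.calls = name, log, 0

    def __call__(self, *args):
        self.calls += 1
        self.log.append((self.name, args))
        tag = f"{self.name}#{self.calls}"
        # Two labelled outputs: (out, mid) for a U-Net, (h, c) for the LSTM cell.
        return f"{tag}.0", f"{tag}.1"


def cat(a, b):
    # Stand-in for torch.cat((a, b), 1): returns a label instead of a tensor.
    return f"cat({a}, {b})"


def forward_proposed(unets, rnn, data, h0, c0):
    """Proposed wiring: unet_k receives cat(data_k, unet_out_{k-1}) for k = 2..4."""
    out, mid = unets[0](data[0])
    h, c = rnn(mid, (h0, c0))
    outs = [out]
    for k in range(1, 4):
        out, mid = unets[k](cat(data[k], out), h)
        outs.append(out)
        if k < 3:  # as in the quoted code, unet_mid_4 is not fed back to the LSTM cell
            h, c = rnn(mid, (h, c))
    return outs


log = []
unets = [Recorder(f"unet_{k}", log) for k in range(1, 5)]
rnn = Recorder("rnn", log)
forward_proposed(unets, rnn, ["data_1", "data_2", "data_3", "data_4"], "h0", "c0")
for name, args in log:
    if name.startswith("unet"):
        print(name, args[0])
# unet_1 data_1
# unet_2 cat(data_2, unet_1#1.0)
# unet_3 cat(data_3, unet_2#1.0)
# unet_4 cat(data_4, unet_3#1.0)
```

Running the same trace against the quoted wiring would show cat(data_1, ...) at every stage instead.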

My gan_lstm.py runs, but it doesn't produce any results. Why? It has been running for an hour with no error messages and no output.
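
The report does not say how the script was launched or what it is expected to write, so the cause cannot be pinned down from here. A generic first check (an assumption, not a diagnosis of gan_lstm.py) is to make the training loop print its progress. Python block-buffers stdout when it is not attached to a terminal, so printed output can stay invisible for a long time unless it is flushed. make_progress_logger is a hypothetical helper, not part of the repository.

```python
import sys
import time


def make_progress_logger(total_steps, every=10, stream=sys.stdout):
    """Return a callback that prints a progress line every `every` steps.

    Hypothetical helper. flush=True pushes each line out immediately,
    even when stdout is redirected to a file or a pipe.
    """
    start = time.monotonic()

    def log(step, loss=None):
        # Stay quiet except on every `every`-th step and on the final step.
        if step % every != 0 and step != total_steps:
            return None
        elapsed = time.monotonic() - start
        line = f"step {step}/{total_steps}  elapsed {elapsed:.1f}s"
        if loss is not None:
            line += f"  loss {loss:.4f}"
        print(line, file=stream, flush=True)
        return line

    return log
```

Inside the training loop, call log(step, loss.item()) once per iteration. If no line appears after the first few steps, execution is not reaching that call, which narrows down where the script is spending its time.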