Code error in forward(): t2_to_t1 concatenates t1 and t2 in the wrong order
swjtulinxi commented
def forward(self, t1, t2):
    B, C, H, W = t1.size()
    t1 = t1.view(B, C, 1, H, W)
    t2 = t2.view(B, C, 1, H, W)
    t1_to_t2 = torch.cat([t1, t2], dim=2)  # order: t1, then t2
    t2_to_t1 = torch.cat([t1, t2], dim=2)  # BUG: same order as t1_to_t2; should be torch.cat([t2, t1], dim=2)
    diff_1 = self.conv_context_t1_to_t2(t1_to_t2).view(B, C, H, W)
    diff_2 = self.conv_context_t2_to_t1(t2_to_t1).view(B, C, H, W)
    diff = diff_1 + diff_2
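A quick way to see the problem is to build both concatenations on random tensors and compare them. This is a minimal standalone check, not code from the repository; the shapes are arbitrary illustrative choices. With the line as written, t2_to_t1 is an exact copy of t1_to_t2, so both convolution branches receive the same temporal order. With the proposed fix, t2_to_t1 is t1_to_t2 with the two time steps swapped along dim=2.

```python
import torch

torch.manual_seed(0)

# Arbitrary illustrative shapes (not taken from the repository).
B, C, H, W = 2, 3, 4, 5
t1 = torch.randn(B, C, H, W).view(B, C, 1, H, W)
t2 = torch.randn(B, C, H, W).view(B, C, 1, H, W)

t1_to_t2 = torch.cat([t1, t2], dim=2)        # shape (B, C, 2, H, W)
t2_to_t1_buggy = torch.cat([t1, t2], dim=2)  # line as written in the issue
t2_to_t1_fixed = torch.cat([t2, t1], dim=2)  # correction proposed in the issue

# The buggy version duplicates t1_to_t2, so the "reverse" branch sees the forward order.
print(torch.equal(t1_to_t2, t2_to_t1_buggy))  # True
# The fixed version equals t1_to_t2 with its two time steps reversed along dim=2.
print(torch.equal(t2_to_t1_fixed, torch.flip(t1_to_t2, dims=[2])))  # True
```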
guanyuezhen commented
Thanks a lot for pointing out this error.
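For reference, here is a self-contained sketch of the corrected forward pass inside a module. Only forward() follows the issue; the class name BidirectionalDiff and the two Conv3d layers are hypothetical stand-ins, since the issue does not show how conv_context_t1_to_t2 and conv_context_t2_to_t1 are defined. They are assumed to use a kernel depth of 2, so that the stacked (B, C, 2, H, W) input collapses to depth 1 and the .view(B, C, H, W) calls are valid. The issue also does not show what happens after diff is computed, so the sketch simply returns it.

```python
import torch
import torch.nn as nn


class BidirectionalDiff(nn.Module):
    """Hypothetical module: forward() follows the issue, the conv layers are assumed."""

    def __init__(self, channels: int):
        super().__init__()
        # Assumed layers: kernel depth 2 fuses the two stacked time steps into one,
        # and padding (0, 1, 1) keeps H and W unchanged.
        self.conv_context_t1_to_t2 = nn.Conv3d(
            channels, channels, kernel_size=(2, 3, 3), padding=(0, 1, 1))
        self.conv_context_t2_to_t1 = nn.Conv3d(
            channels, channels, kernel_size=(2, 3, 3), padding=(0, 1, 1))

    def forward(self, t1, t2):
        B, C, H, W = t1.size()
        t1 = t1.view(B, C, 1, H, W)
        t2 = t2.view(B, C, 1, H, W)
        t1_to_t2 = torch.cat([t1, t2], dim=2)  # order: t1, then t2
        t2_to_t1 = torch.cat([t2, t1], dim=2)  # corrected order: t2, then t1
        diff_1 = self.conv_context_t1_to_t2(t1_to_t2).view(B, C, H, W)
        diff_2 = self.conv_context_t2_to_t1(t2_to_t1).view(B, C, H, W)
        return diff_1 + diff_2


torch.manual_seed(0)
model = BidirectionalDiff(channels=8)
x1, x2 = torch.randn(2, 8, 16, 16), torch.randn(2, 8, 16, 16)
out = model(x1, x2)
print(out.shape)  # torch.Size([2, 8, 16, 16])
```

With the fix, the two branches see the time steps in opposite orders, which is presumably what the layer names t1_to_t2 and t2_to_t1 intend; with the original line, the second branch would only ever see the t1-then-t2 order.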
hexiao0275 commented
Great work from the author, looking forward to your next project!