guanyuezhen/AR-CDNet

Code error: t2_to_t1 is concatenated in the same order as t1_to_t2


def forward(self, t1, t2):
    B, C, H, W = t1.size()
    t1 = t1.view(B, C, 1, H, W)
    t2 = t2.view(B, C, 1, H, W)
    t1_to_t2 = torch.cat([t1, t2], dim=2)
    t2_to_t1 = torch.cat([t1, t2], dim=2)  # should be torch.cat([t2, t1], dim=2)
    diff_1 = self.conv_context_t1_to_t2(t1_to_t2).view(B, C, H, W)
    diff_2 = self.conv_context_t2_to_t1(t2_to_t1).view(B, C, H, W)
    diff = diff_1 + diff_2
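
To illustrate why the concatenation order matters, here is a minimal sketch. It assumes the `conv_context_*` layers are 3D convolutions whose temporal kernel size is 2. That is a guess based on the `.view(B, C, H, W)` call that follows them, not something confirmed from the repository. The layer `conv` and the tensor sizes are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, C, H, W = 2, 4, 8, 8
t1 = torch.randn(B, C, 1, H, W)
t2 = torch.randn(B, C, 1, H, W)

# Hypothetical stand-in for conv_context_*: a temporal kernel of size 2
# with no temporal padding collapses dim 2 from 2 to 1.
conv = nn.Conv3d(C, C, kernel_size=(2, 3, 3), padding=(0, 1, 1))

t1_to_t2 = torch.cat([t1, t2], dim=2)  # shape (B, C, 2, H, W), order: t1 then t2
t2_to_t1 = torch.cat([t2, t1], dim=2)  # shape (B, C, 2, H, W), order: t2 then t1

out_fwd = conv(t1_to_t2).view(B, C, H, W)
out_rev = conv(t2_to_t1).view(B, C, H, W)

# The same layer gives different outputs for the two orderings,
# so the order of the list passed to torch.cat is not interchangeable.
order_matters = not torch.allclose(out_fwd, out_rev)
print(order_matters)
```

With the line as originally written, both branches receive the identical `[t1, t2]` tensor, so the `t2_to_t1` branch never sees the reversed temporal order that its name suggests.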

Thanks a lot for pointing out this error.

Very good author, looking forward to your next work!