moothes/A2S-v2

Doubts about the training environment

Ayews opened this issue · 4 comments

Ayews commented

Very impressive work. Unfortunately, we encountered some problems while reproducing it.

Specifically, when training the RGB SOD task on DUTS-TR, the evaluation results we obtained were significantly different from those reported in the paper. However, when using the pre-trained weights you provided, the results matched those in the paper. We suspect that this is caused by differences in the training environment.

We hope you can share the training environment you used (e.g., the Python and PyTorch versions) to help us complete the reproduction.

Thanks very much.

moothes commented

Thanks for your attention!
Sorry, I've moved to a new university, so the previous server is no longer accessible to me. As I recall, it may have been Python 3.5 + PyTorch 1.10.
I also trained the code on my current server (Python 3.9 + PyTorch 1.13), and it gets (0.916, 0.943, 0.039) (mean-F, E-measure, MAE), which is very similar to the (0.917, 0.945, 0.038) reported in our paper.
Therefore, I don't think the package versions are influential for the results.
Can you provide more details about your reproduction (e.g., the training log) so I can find out the reason?

Ayews commented

Thank you for your reply! We have identified the issue, which was the incorrect loading of the pretrained weights (TAT). However, we still have two questions that we would like to seek your advice on.

  1. The design of the Boundary-aware Texture Matching (BTM) loss is very clever, and we believe that it can effectively search for relevant information in the input image while avoiding interference from irrelevant information. It provides a good idea for unsupervised SOD tasks. Unfortunately, it might be a bit difficult to understand, so we have added some comments to the code implementation:
import torch
import torch.nn.functional as F

# Boundary-aware Texture Matching Loss
# (get_contour is the boundary-extraction helper defined elsewhere in the repo)
def BTMLoss(pred, image, radius, config=None):
    alpha = config['rgb']
    modal = config['trset']  # 'c' for RGB
    num_modal = len(modal) if 'c' in modal else len(modal) + 1  # RGB is always counted as a modality

    slices = range(0, 3 * num_modal + 1, 3)  # channel offsets of each modality inside features
    sal_map = F.interpolate(pred, scale_factor=0.25, mode='bilinear', align_corners=True)
    image_ = F.interpolate(image, size=sal_map.shape[-2:], mode='bilinear', align_corners=True)
    mask = get_contour(sal_map)
    features = torch.cat([image_, sal_map], dim=1)  # the saliency map occupies the last channel of features

    N, C, H, W = features.shape
    diameter = 2 * radius + 1  # sampling range; in Figure 5 of the paper, diameter = 3 and radius = 1
    # For each point in the saliency map and the image, take the (diameter^2 - 1) surrounding points
    # as candidates and compute the difference between the center point and each candidate.
    kernels = F.unfold(features, diameter, 1, radius).view(N, C, diameter, diameter, H, W)
    kernels = kernels - kernels[:, :, radius, radius, :, :].view(N, C, 1, 1, H, W)
    dis_modal = 1
    for idx, start in enumerate(slices):
        # Traverse the channels of each modality.
        if idx == len(slices) - 1:
            continue
        # When the center point in the image is close to a candidate point, the corresponding value
        # in dis_map is close to 1; otherwise it is close to 0.
        # Therefore, the loss only takes effect when the center point in the saliency map deviates from
        # a candidate point while the corresponding center point in the image stays consistent with it.
        dis_map = (-alpha * kernels[:, start:slices[idx + 1]] ** 2).sum(dim=1, keepdim=True).exp()
        # Only RGB
        if config['only_rgb'] and idx > 0:
            dis_map = dis_map * 0 + 1
        dis_modal = dis_modal * dis_map

    dis_sal = torch.abs(kernels[:, slices[-1]:])  # take the saliency channel (the last channel of features)
    distance = dis_modal * dis_sal

    loss = distance.view(N, 1, (radius * 2 + 1) ** 2, H, W).sum(dim=2)
    loss = torch.sum(loss * mask) / torch.sum(mask)
    return loss

Based on our understanding, the BTM loss only takes effect when the center point in the saliency map deviates from a candidate point while the center point at the corresponding position in the image stays consistent with that candidate point. This allows the loss to effectively avoid interference. However, it seems somewhat conservative. We wonder whether the BTM loss could also be applied to candidate points that are more consistent with the center point. We have made some attempts and experiments but did not observe significant improvements, so we would like to ask for your guidance on this matter. (A minimal call sketch of BTMLoss is included after this list for reference.)

  2. We believe that your work is very pioneering. May we build further on your work and use a similar naming convention, such as 'A2S-V3', when publishing?
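
For reference, here is a minimal sketch of how the BTMLoss above can be invoked, to make the expected shapes and config keys concrete. The get_contour stand-in and all config values below are our own placeholders for illustration; they are not the repo's implementation or the settings used in the paper.

import torch
import torch.nn.functional as F

def get_contour(sal):
    # Hypothetical stand-in for the repo's get_contour: mark pixels whose local
    # neighbourhood contains both high and low saliency, i.e., pixels near a predicted boundary.
    dilated = F.max_pool2d(sal, kernel_size=5, stride=1, padding=2)
    eroded = -F.max_pool2d(-sal, kernel_size=5, stride=1, padding=2)
    return ((dilated - eroded) > 0.1).float()

N, H, W = 2, 64, 64
pred = torch.rand(N, 1, H, W)    # saliency prediction (after sigmoid), 1 channel
image = torch.rand(N, 3, H, W)   # RGB input, so features = [RGB | saliency] has 4 channels
config = {'rgb': 0.1, 'trset': 'c', 'only_rgb': True}  # placeholder values, not the paper's settings
loss = BTMLoss(pred, image, radius=1, config=config)   # radius=1 gives a 3x3 sampling window
print(loss.item())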

We would appreciate your response and look forward to hearing from you!

moothes commented

I'm glad that you've resolved the reproduction issue.
As for your questions:

  1. I'm not sure I understand your question correctly. Actually, our BTM loss takes effect on the boundary pixels of the current saliency predictions, which are dynamic during training. Our motivation is to measure whether the current saliency boundaries are consistent with the boundaries in RGB or other modalities (depth map, optical flow, or thermal image). If not, the current saliency boundary is considered to be wrong. In this way, the saliency boundaries can be optimized during training (a toy numeric sketch of this matching is given after this list).

  2. I would be grateful to witness the lasting impact my work has made on the SOD task. You are welcome to use 'A2S-V3' if your work will be published in a top-tier (CCF-A or well-recognized, e.g., ECCV) journal or conference. My hope is to maintain the good reputation of this series, so that every work in it contributes significantly to the development of the SOD task.
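
To make the matching on boundary pixels concrete, here is a toy numeric illustration of the idea in point 1; the pixel values and the alpha weight are invented for this example and do not come from the repo or the paper.

import torch

alpha = 10.0                               # placeholder weight for the image-affinity term
I_c  = torch.tensor([0.50, 0.40, 0.30])    # RGB of a center pixel on the predicted saliency boundary
I_k1 = torch.tensor([0.52, 0.41, 0.31])    # neighbour in the same textured region as the center
I_k2 = torch.tensor([0.90, 0.10, 0.05])    # neighbour across a real image edge
S_c, S_k1, S_k2 = 0.9, 0.2, 0.2            # saliency: the predicted boundary separates c from both neighbours

for I_k, S_k in [(I_k1, S_k1), (I_k2, S_k2)]:
    affinity = torch.exp(-alpha * ((I_c - I_k) ** 2).sum()).item()  # ~1 when the image says "same region"
    penalty = affinity * abs(S_c - S_k)                             # contribution of this pair to the loss
    print(f'affinity={affinity:.3f}  penalty={penalty:.3f}')

# The first pair is heavily penalised (the saliency boundary cuts through a uniform image region),
# while the second pair contributes little (the saliency boundary coincides with an image edge):
# this is the consistency check on boundary pixels described above.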

Ayews commented

Thanks! My questions have been fully answered. Closing the issue.