Vibashan/irg-sfda

a question about the positive sample mask

qchqy opened this issue · 1 comment

qchqy commented

Hi, thanks for your release.
I have read your code and the paper. In the paper, the positive sample mask is obtained from the features after the GCN, but in the code you seem to use the features before the GCN (s_box_features) to obtain the positive sample mask. How should I understand this?

    # Student: pooled RoI features (pre-GCN) and the logits computed from them
    s_box_features = self.roi_heads._shared_roi_transform([features['res4']], [t_proposals[0].proposal_boxes]) #t_proposals[0], results[1]
    s_roih_logits = self.roi_heads.box_predictor(s_box_features.mean(dim=[2, 3]))

    # Teacher: pooled RoI features (pre-GCN) and the corresponding logits
    t_box_features = model_teacher.roi_heads._shared_roi_transform([t_features['res4']], [t_proposals[0].proposal_boxes])
    t_roih_logits = model_teacher.roi_heads.box_predictor(t_box_features.mean(dim=[2, 3]))

    # Post-GCN features and the logits computed from them
    s_graph_feat = self.GraphCN(s_box_features.mean(dim=[2, 3]))
    s_graph_logits = self.roi_heads.box_predictor(s_graph_feat)

    t_graph_feat = self.GraphCN(t_box_features.mean(dim=[2, 3]))
    t_graph_logits = model_teacher.roi_heads.box_predictor(t_graph_feat)

    # Distillation / consistency losses
    losses["st_const"] = self.KD_loss(s_roih_logits[0], t_roih_logits[0])
    losses["s_graph_const"] = self.KD_loss(s_graph_logits[0], s_roih_logits[0])
    losses["t_graph_const"] = self.KD_loss(t_graph_logits[0], t_roih_logits[0])
    # Note: the pre-GCN pooled features (s_box_features) are passed here
    losses["graph_conloss"] = self.Graph_conloss(t_box_features.mean(dim=[2, 3]), s_box_features.mean(dim=[2, 3]), self.GraphCN)
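For clarity, this is how the call above maps onto the Graph_conloss.forward signature shown below. This just restates the call site; the names on the left are the forward arguments:

    # Mapping the call onto Graph_conloss.forward(t_feat, s_feat, graph_cn, ...):
    t_feat   = t_box_features.mean(dim=[2, 3])  # pooled teacher RoI feature, before GraphCN
    s_feat   = s_box_features.mean(dim=[2, 3])  # pooled student RoI feature, before GraphCN
    graph_cn = self.GraphCN                     # the GCN module itself is passed in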
And here is Graph_conloss.forward, where the positive mask is computed:
    def forward(self, t_feat, s_feat, graph_cn, labels=None, mask=None):
        # The positive mask is built from the input feature s_feat, using the
        # GCN's query/key projections but not the GCN's output
        qx = graph_cn.graph.wq(s_feat)
        kx = graph_cn.graph.wk(s_feat)
        sim_mat = qx.matmul(kx.transpose(-1, -2))
        dot_mat = sim_mat.detach().clone()

        # Min-max normalize each row, then threshold: entries above 0.5 are
        # positives; the diagonal (self-pairs) is always positive
        thresh = 0.5
        dot_mat -= dot_mat.min(1, keepdim=True)[0]
        dot_mat /= dot_mat.max(1, keepdim=True)[0]
        mask = ((dot_mat > thresh) * 1).detach().clone()
        mask.fill_diagonal_(1)

        # Two projection heads produce anchor/contrast embeddings from s_feat
        anchor_feat = self.head_1(s_feat)
        contrast_feat = self.head_2(s_feat)

        anchor_feat = F.normalize(anchor_feat, dim=1)
        contrast_feat = F.normalize(contrast_feat, dim=1)

        # Temperature-scaled similarities; subtract the row max for numerical stability
        ss_anchor_dot_contrast = torch.div(torch.matmul(anchor_feat, contrast_feat.T), self.temperature)  # e.g. torch.Size([6, 6])
        logits_max, _ = torch.max(ss_anchor_dot_contrast, dim=1, keepdim=True)  # e.g. torch.Size([6, 1]), row-wise max
        ss_graph_logits = ss_anchor_dot_contrast - logits_max.detach()

        # Row-wise log-softmax, averaged over the positives given by the mask
        ss_graph_all_logits = torch.exp(ss_graph_logits)
        ss_log_prob = ss_graph_logits - torch.log(ss_graph_all_logits.sum(1, keepdim=True))
        ss_mean_log_prob_pos = (mask * ss_log_prob).sum(1) / mask.sum(1)

        # Supervised-contrastive-style loss
        ss_loss = -(self.temperature / self.base_temperature) * ss_mean_log_prob_pos
        ss_loss = ss_loss.mean()

        return ss_loss
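For comparison, if the mask were built from the features after the GCN, as the paper describes, the start of forward might look like the sketch below. This is a hypothetical variant, not the repository's code; it assumes graph_cn(s_feat) returns the graph-convolved feature, mirroring self.GraphCN(...) in the training snippet above:

    # Hypothetical variant: run the GCN first, then derive the mask from its output
    s_graph_feat = graph_cn(s_feat)            # post-GCN student feature (assumption)
    qx = graph_cn.graph.wq(s_graph_feat)
    kx = graph_cn.graph.wk(s_graph_feat)
    sim_mat = qx.matmul(kx.transpose(-1, -2))
    dot_mat = sim_mat.detach().clone()
    # ...then min-max normalize, threshold at 0.5, and fill the diagonal, as above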

Hi @qchqy,

Thank you for your interest in our work.

  1. To clarify, are you referring specifically to the mask computation in Graph_conloss, where graph_cn.graph.wq(s_feat) and graph_cn.graph.wk(s_feat) are used?

  2. Do you see s_feat here as the pre-GCN feature, and is this the source of your concern?