a question about the positive sample mask
qchqy opened this issue · 1 comment
qchqy commented
Hi, thanks for your release.
I have read your code and paper. In the paper, the positive sample mask is obtained from the feature after the GCN, but in your code you seem to use the feature before the GCN (s_box_features) to obtain the positive sample mask. How should I understand this? (I sketch the two variants I am comparing right after the Graph_conloss snippet below.)
```python
# student / teacher RoI features and box-predictor logits
s_box_features = self.roi_heads._shared_roi_transform([features['res4']], [t_proposals[0].proposal_boxes]) #t_proposals[0], results[1]
s_roih_logits = self.roi_heads.box_predictor(s_box_features.mean(dim=[2, 3]))
t_box_features = model_teacher.roi_heads._shared_roi_transform([t_features['res4']], [t_proposals[0].proposal_boxes])
t_roih_logits = model_teacher.roi_heads.box_predictor(t_box_features.mean(dim=[2, 3]))

# GCN-refined features and their logits
s_graph_feat = self.GraphCN(s_box_features.mean(dim=[2, 3]))
s_graph_logits = self.roi_heads.box_predictor(s_graph_feat)
t_graph_feat = self.GraphCN(t_box_features.mean(dim=[2, 3]))
t_graph_logits = model_teacher.roi_heads.box_predictor(t_graph_feat)

# consistency / distillation losses
losses["st_const"] = self.KD_loss(s_roih_logits[0], t_roih_logits[0])
losses["s_graph_const"] = self.KD_loss(s_graph_logits[0], s_roih_logits[0])
losses["t_graph_const"] = self.KD_loss(t_graph_logits[0], t_roih_logits[0])

# the graph contrastive loss receives the pooled (pre-GCN) box features
losses["graph_conloss"] = self.Graph_conloss(t_box_features.mean(dim=[2, 3]), s_box_features.mean(dim=[2, 3]), self.GraphCN)
```
Inside Graph_conloss, the mask is computed from s_feat, i.e. the pre-GCN s_box_features.mean(dim=[2, 3]) passed in above:
```python
def forward(self, t_feat, s_feat, graph_cn, labels=None, mask=None):
    # positive sample mask from the query/key projections of the graph module, applied to s_feat
    qx = graph_cn.graph.wq(s_feat)
    kx = graph_cn.graph.wk(s_feat)
    sim_mat = qx.matmul(kx.transpose(-1, -2))
    dot_mat = sim_mat.detach().clone()
    thresh = 0.5
    # row-wise min-max normalisation, then threshold to get the binary mask
    dot_mat -= dot_mat.min(1, keepdim=True)[0]
    dot_mat /= dot_mat.max(1, keepdim=True)[0]
    mask = ((dot_mat > thresh) * 1).detach().clone()
    mask.fill_diagonal_(1)

    # supervised-contrastive loss between two projection heads of s_feat
    anchor_feat = self.head_1(s_feat)
    contrast_feat = self.head_2(s_feat)
    anchor_feat = F.normalize(anchor_feat, dim=1)
    contrast_feat = F.normalize(contrast_feat, dim=1)
    ss_anchor_dot_contrast = torch.div(torch.matmul(anchor_feat, contrast_feat.T), self.temperature)  ##### torch.Size([6, 6])
    logits_max, _ = torch.max(ss_anchor_dot_contrast, dim=1, keepdim=True)  ##### torch.Size([6, 1]) - max value along dim=1
    ss_graph_logits = ss_anchor_dot_contrast - logits_max.detach()
    ss_graph_all_logits = torch.exp(ss_graph_logits)
    ss_log_prob = ss_graph_logits - torch.log(ss_graph_all_logits.sum(1, keepdim=True))
    ss_mean_log_prob_pos = (mask * ss_log_prob).sum(1) / mask.sum(1)

    # loss
    ss_loss = - (self.temperature / self.base_temperature) * ss_mean_log_prob_pos
    ss_loss = ss_loss.mean()
    return ss_loss
```
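To make my question concrete, here is a minimal sketch of the two variants I am comparing. build_positive_mask is just a hypothetical helper wrapping the mask computation above; the "post-GCN" variant is only my reading of the paper, not code from this repository.

```python
def build_positive_mask(feat, graph_cn, thresh=0.5):
    # same mask construction as in Graph_conloss above
    qx = graph_cn.graph.wq(feat)
    kx = graph_cn.graph.wk(feat)
    dot_mat = qx.matmul(kx.transpose(-1, -2)).detach().clone()
    dot_mat -= dot_mat.min(1, keepdim=True)[0]
    dot_mat /= dot_mat.max(1, keepdim=True)[0]
    mask = (dot_mat > thresh).long()
    mask.fill_diagonal_(1)
    return mask

# variant 1 -- what the current code does: mask from the pooled RoI feature (pre-GCN)
# mask_pre = build_positive_mask(s_box_features.mean(dim=[2, 3]), self.GraphCN)

# variant 2 -- what I understood from the paper: mask from the GCN output
# s_graph_feat = self.GraphCN(s_box_features.mean(dim=[2, 3]))
# mask_post = build_positive_mask(s_graph_feat, self.GraphCN)
```

If the pre-GCN version is intended, I would just like to understand the reasoning behind it.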
Vibashan commented
Hi @qchqy,
Thank you for your interest in our work.
To clarify:

- Are you referring specifically to the mask computation in Graph_conloss, where graph_cn.graph.wq(s_feat) and graph_cn.graph.wk(s_feat) are used?
- Do you see s_feat here as the pre-GCN feature, and is this the source of your concern?