gmftbyGMFTBY/SimpleReDial-v1

Question about interaction layer

minji2744 opened this issue · 3 comments

I am trying to find where the interaction layer is implemented, but I have had no success. Could you help me with the interaction layer?

I searched for it in model > InteractionModels, but there seems to be no code for the model architecture there. I read your paper and found that the interaction layer is composed of a transformer decoder. How can I implement this?

Sorry for the late response; the code is here:

```python

from model.utils import *
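# NOTE: model.utils is assumed to re-export torch, torch.nn as nn,
# torch.nn.functional as F, and the BertEmbedding wrapper used below.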

class BERTDualSCMSmallEncoder(nn.Module):

def __init__(self, **args):
    super(BERTDualSCMSmallEncoder, self).__init__()
    model = args['pretrained_model']
    self.ctx_encoder = BertEmbedding(model=model)
    self.can_encoder = BertEmbedding(model=model)
    # decoder layer
    decoder_layer = nn.TransformerDecoderLayer(d_model=768, nhead=args['nhead'])
    self.fusion_encoder = nn.TransformerDecoder(decoder_layer, num_layers=args['num_layers'])
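    # this nn.TransformerDecoder is the interaction (comparison) layer described
    # in the paper: the fused context-response vectors form the target sequence
    # and the context token states serve as the cross-attention memory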
    # squeeze and gate
    self.squeeze = nn.Sequential(
        nn.Dropout(p=args['dropout']),
        nn.Linear(768*2, 768)
    )
    self.gate = nn.Sequential(
        nn.Dropout(p=args['dropout']),
        nn.Linear(768*3, 768)
    )
    self.args = args
    self.convert = nn.Sequential(
        nn.Dropout(p=args['dropout']),
        nn.Linear(768, 768),
        nn.Tanh(),
        nn.Dropout(p=args['dropout']),
        nn.Linear(768, 768),
    )

def _encode(self, cid, rid, cid_mask, rid_mask, test=False):
    # cid_rep_whole: [B_c, S, E]
    cid_rep_whole = self.ctx_encoder(cid, cid_mask, hidden=True)
    # cid_rep: [B_c, E]
    cid_rep = cid_rep_whole[:, 0, :]
    # cid_rep_: [B_c, 1, E]
    cid_rep_ = cid_rep_whole[:, 0, :].unsqueeze(1)
    # rid_rep: [B_r, E]
    rid_rep = self.can_encoder(rid, rid_mask)

    cid_rep_mt, rid_rep_mt = self.convert(cid_rep), self.convert(rid_rep)

    dps = []
    turn_size = len(cid) if test else self.args['small_turn_size']
    for i_b in range(0, len(cid), turn_size):
        cid_rep_p = cid_rep[i_b:i_b+turn_size]
        rid_rep_p = rid_rep[i_b:i_b+turn_size]
        cid_rep_whole_p = cid_rep_whole[i_b:i_b+turn_size, :, :]
        cid_mask_p = cid_mask[i_b:i_b+turn_size, :]
        cid_rep_p_ = cid_rep_[i_b:i_b+turn_size, :, :]
        rid_size, cid_size = len(rid_rep_p), len(cid_rep_p)
        # cid_rep: [B_r, B_c, E]
        cid_rep_p = cid_rep_p.unsqueeze(0).expand(rid_size, -1, -1)
        # rid_rep: [B_r, B_c, E]
        rid_rep_p = rid_rep_p.unsqueeze(1).expand(-1, cid_size, -1)
        # rep: [B_r, B_c, 2*E]
        rep = torch.cat([cid_rep_p, rid_rep_p], dim=-1)
        # rep: [B_r, B_c, E]
        rep = self.squeeze(rep)    

        # decoder memory: [S, B_c, E]; keep it in a separate variable so that
        # cid_rep_whole, which is sliced again in later iterations, is not clobbered
        cid_rep_whole_p = cid_rep_whole_p.permute(1, 0, 2)
        # rest: [B_r, B_c, E]
        rest = self.fusion_encoder(
            rep,
            cid_rep_whole_p,
            memory_key_padding_mask=~cid_mask_p.to(torch.bool),
        )

        ## gate
        # gate: [B_r, B_c, E]
        gate = torch.sigmoid(
            self.gate(
                torch.cat([
                    rid_rep_p,
                    cid_rep_p,
                    rest,
                ], dim=-1) 
            )        
        )
        # rest: [B_r, B_c, E]
        rest = gate * rid_rep_p + (1 - gate) * rest
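        # the sigmoid gate interpolates between the raw response embedding and
        # the interaction-aware representation produced by the decoder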
        # rest: [B_c, E, B_r]
        rest = rest.permute(1, 2, 0)
        # dp: [B_c, B_r]
        cid_rep_p_ = F.normalize(cid_rep_p_, dim=-1)
        rest = F.normalize(rest, dim=-1)
        dp = torch.bmm(cid_rep_p_, rest).squeeze(1)
        dps.append(dp)
    return dps, cid_rep_mt, rid_rep_mt

@torch.no_grad()
def predict(self, batch):
    cid = batch['ids']
    cid_mask = torch.ones_like(cid)
    rid = batch['rids']
    rid_mask = batch['rids_mask']
    dp, _, _ = self._encode(cid, rid, cid_mask, rid_mask)    # [1, 10]
    return dp[0].squeeze()

def forward(self, batch):
    cid = batch['ids']
    rid = batch['rids']
    cid_mask = batch['ids_mask']
    rid_mask = batch['rids_mask']
    batch_size = len(cid)

    dps, cid_rep_mt, rid_rep_mt = self._encode(cid, rid, cid_mask, rid_mask)
    loss = 0
    # multi-task: recall training
    if self.args['coarse_recall_loss']:
        dp_mt = torch.matmul(cid_rep_mt, rid_rep_mt.t())
        mask = torch.zeros_like(dp_mt)
        mask[range(batch_size), range(batch_size)] = 1.
        loss_ = F.log_softmax(dp_mt, dim=-1) * mask
        loss += (-loss_.sum(dim=1)).mean()
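        # the masked log-softmax above is cross-entropy with the diagonal
        # (matched context-response pairs) as the labels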

    # multi-task: rerank training (ranking loss)
    ## dp: [B_c, B_r]
    acc = 0
    loss_margin = 0
    for dp in dps:
        gold_score = torch.diagonal(dp).unsqueeze(dim=-1)    # [B_c, 1]
        difference = gold_score - dp    # [B_c, B_r]
        loss_matrix = torch.clamp(self.args['margin'] - difference, min=0.)   # [B_c, B_r]
        loss_margin += loss_matrix.mean()
        
        dp /= self.args['temp']
        mask = torch.zeros_like(dp)
        mask[range(batch_size), range(batch_size)] = 1.
        loss_ = F.log_softmax(dp, dim=-1) * mask
        loss += (-loss_.sum(dim=1)).mean()
        acc += (dp.max(dim=-1)[1] == torch.arange(batch_size).cuda()).to(torch.float).mean().item()
    acc /= len(dps)
    return loss, loss_margin, acc

class BERTDualSCMEncoder(nn.Module):

def __init__(self, **args):
    super(BERTDualSCMEncoder, self).__init__()
    model = args['pretrained_model']
    self.ctx_encoder = BertEmbedding(model=model)
    self.can_encoder = BertEmbedding(model=model)
    decoder_layer = nn.TransformerDecoderLayer(d_model=768, nhead=args['nhead'])
    self.fusion_encoder = nn.TransformerDecoder(decoder_layer, num_layers=args['num_layers'])
    self.args = args

def _encode(self, cid, rid, cid_mask, rid_mask):
    rid_size, cid_size = len(rid), len(cid)
    # cid_rep_whole: [B_c, S, E]
    cid_rep_whole = self.ctx_encoder(cid, cid_mask, hidden=True)
    # cid_rep: [B_c, E]
    cid_rep = cid_rep_whole[:, 0, :]
    # cid_rep_: [B_c, 1, E]
    cid_rep_ = cid_rep_whole[:, 0, :].unsqueeze(1)
    # rid_rep: [B_r, E]
    # rid_rep = self.can_encoder(rid, rid_mask)
    rid_rep = torch.zeros(rid_size, 768).cuda()
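    # the candidate encoder is bypassed here: a zero placeholder keeps the shapes
    # intact, so the fusion module scores candidates from context information alone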

    # cid_rep_mt, rid_rep_mt = self.convert_ctx(cid_rep), self.convert_res(rid_rep)
    cid_rep_mt, rid_rep_mt = cid_rep.clone(), rid_rep.clone()

    ## combine context and response embeddings before comparison
    # cid_rep: [B_r, B_c, E]
    cid_rep = cid_rep.unsqueeze(0).expand(rid_size, -1, -1)
    # rid_rep: [B_r, B_c, E]
    rid_rep = rid_rep.unsqueeze(1).expand(-1, cid_size, -1)
    rep = rid_rep + cid_rep
    # rep: [B_r, B_c, E] (element-wise sum, not concatenation)

    # cid_rep_whole: [S, B_c, E]
    cid_rep_whole = cid_rep_whole.permute(1, 0, 2)
    # rest: [B_r, B_c, E]
    rest = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )

    # rest: [B_c, E, B_r]
    rest = rest.permute(1, 2, 0)
    # dp: [B_c, B_r]
    dp_dp = torch.bmm(cid_rep_, rest).squeeze(1)
    return dp_dp, cid_rep_mt, rid_rep_mt

@torch.no_grad()
def get_cand(self, ids, ids_mask):
    self.eval()
    rest = self.can_encoder(ids, ids_mask)
    rest = self.convert_res(rest)
    return rest

@torch.no_grad()
def get_ctx(self, ids, ids_mask):
    self.eval()
    rest = self.ctx_encoder(ids, ids_mask)
    rest = self.convert_ctx(rest)
    return rest

@torch.no_grad()
def predict(self, batch):
    self.eval()
    cid = batch['ids']
    cid_mask = torch.ones_like(cid)
    rid = batch['rids']
    rid_mask = batch['rids_mask']
    dp, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask)    # [1, 10]
    return dp.squeeze()

def forward(self, batch):
    cid = batch['ids']
    rid = batch['rids']
    cid_mask = batch['ids_mask']
    rid_mask = batch['rids_mask']
    batch_size = len(cid)

    dp, cid_rep_mt, rid_rep_mt = self._encode(cid, rid, cid_mask, rid_mask)
    loss, loss_margin = 0, 0
    # multi-task: recall training
    if self.args['coarse_recall_loss']:
        dp_mt = torch.matmul(cid_rep_mt, rid_rep_mt.t())
        mask = torch.zeros_like(dp_mt)
        mask[range(batch_size), range(batch_size)] = 1.
        loss_ = F.log_softmax(dp_mt, dim=-1) * mask
        loss += (-loss_.sum(dim=1)).mean()

    # multi-task: rerank training (ranking loss)
    ## dp: [B_c, B_r]
    # gold_score = torch.diagonal(dp).unsqueeze(dim=-1)    # [B_c, 1]
    # difference = gold_score - dp    # [B_c, B_r]
    # loss_matrix = torch.clamp(self.args['margin'] - difference, min=0.)   # [B_c, B_r]
    # loss_margin += loss_matrix.mean()
    
    mask = torch.zeros_like(dp)
    mask[range(batch_size), range(batch_size)] = 1.
    loss_ = F.log_softmax(dp, dim=-1) * mask
    loss += (-loss_.sum(dim=1)).mean()
    
    acc = (dp.max(dim=-1)[1] == torch.arange(batch_size).cuda()).to(torch.float).mean().item()
    return loss, acc

class BERTDualSCMHNEncoder(nn.Module):

def __init__(self, **args):
    super(BERTDualSCMHNEncoder, self).__init__()
    model = args['pretrained_model']
    self.ctx_encoder = BertEmbedding(model=model, add_token=1)
    self.can_encoder = BertEmbedding(model=model, add_token=1)
    decoder_layer = nn.TransformerDecoderLayer(d_model=768, nhead=args['nhead'])
    self.fusion_encoder = nn.TransformerDecoder(decoder_layer, num_layers=args['num_layers'])
    self.topk = 1 + args['gray_cand_num']
    self.args = args

def _encode(self, cid, rid, cid_mask, rid_mask, is_test=False, before_comp=False):
    rid_size, cid_size = len(rid), len(cid)
    # cid_rep_whole: [B_c, S, E]
    cid_rep_whole = self.ctx_encoder(cid, cid_mask, hidden=True)
    # cid_rep: [B_c, E]
    cid_rep = cid_rep_whole[:, 0, :]
    # cid_rep_: [B_c, 1, E]
    cid_rep_ = cid_rep_whole[:, 0, :].unsqueeze(1)
    # rid_rep: [B_r*K, E]
    if is_test:
        rid_rep = self.can_encoder(rid, rid_mask)
    else:
        rid_rep = self.can_encoder(rid, rid_mask)
        # rid_rep_whole: [B_r, K, E]
        rid_rep_whole = torch.stack(torch.split(rid_rep, self.topk))
        # rid_rep: [B_r, E]
        rid_rep = rid_rep_whole[:, 0, :]
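        # each group of K = topk candidates stores the gold response at index 0,
        # followed by the retrieved hard ("gray") negatives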

    ## combine context and response embeddings before comparison
    # rep_cid_backup: [B_r, B_c, E]
    rep_rid = rid_rep.unsqueeze(1).expand(-1, cid_size, -1)
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)
    # rep: [B_r, B_c, E]
    rep = rep_cid + rep_rid
    # cid_rep_whole: [S, B_c, E]
    cid_rep_whole = cid_rep_whole.permute(1, 0, 2)
    # rest: [B_r, B_c, E]
    rest = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [B_c, E, B_r]
    rest = rest.permute(1, 2, 0)
    # dp: [B_c, B_r]
    dp = torch.bmm(cid_rep_, rest).squeeze(1)
    if is_test:
        return dp, cid_rep, rid_rep

    ### hard negative comparison
    # rep_rid: [K, B_r, E] (rid_rep_whole permuted from [B_r, K, E])
    rep_rid = rid_rep_whole.permute(1, 0, 2)
    # rep_cid: [K, B_c, E]
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)
    # rep: [K, B_r, E] (B_r == B_c within a batch)
    rep = rep_cid + rep_rid
    # rest: [K, B_r, E]
    rest = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [K, B_r, E] -> [B_r, E, K]
    rest = rest.permute(1, 2, 0)
    # dp: [B_c, K]
    dp2 = torch.bmm(cid_rep_, rest).squeeze(1)
    if before_comp:
        return dp, dp2, cid_rep, rid_rep
    else:
        return dp, dp2

@torch.no_grad()
def predict(self, batch):
    cid = batch['ids']
    cid_mask = torch.ones_like(cid)
    rid = batch['rids']
    rid_mask = batch['rids_mask']
    dp, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask, is_test=True)    # [1, 10]
    # dp = torch.matmul(cid_rep, rid_rep.t())
    dp = F.softmax(dp.squeeze(), dim=-1)
    return dp

def forward(self, batch):
    cid = batch['ids']
    # rid: [B_r*K, S]
    rid = batch['rids']
    cid_mask = batch['ids_mask']
    rid_mask = batch['rids_mask']
    batch_size = len(cid)

    # [B_c, B_r]
    loss = 0
    if self.args['before_comp']:
        dp, dp2, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask, before_comp=True)
        # before comparison, optimize the absolute semantic space
        dot_product = torch.matmul(cid_rep, rid_rep.t())
        mask = torch.zeros_like(dot_product)
        mask[range(batch_size), range(batch_size)] = 1.
        loss_ = F.log_softmax(dot_product, dim=-1) * mask
        loss += (-loss_.sum(dim=1)).mean()
    else:
        dp, dp2 = self._encode(cid, rid, cid_mask, rid_mask)
    mask = torch.zeros_like(dp)
    mask[range(batch_size), range(batch_size)] = 1.
    loss_ = F.log_softmax(dp, dim=-1) * mask
    loss += (-loss_.sum(dim=1)).mean()

    mask = torch.zeros_like(dp2)
    mask[:, 0] = 1.
    loss_ = F.log_softmax(dp2, dim=-1) * mask
    loss += (-loss_.sum(dim=1)).mean()
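    # for dp2 the gold response sits at column 0 of every hard-negative group,
    # hence the mask on the first column above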

    acc = (dp.max(dim=-1)[1] == torch.arange(batch_size).cuda()).to(torch.float).mean().item()
    return loss, acc

class BERTDualSCMHN2Encoder(nn.Module):

'''easy compare, hard compare, easy and hard compare'''

def __init__(self, **args):
    super(BERTDualSCMHN2Encoder, self).__init__()
    model = args['pretrained_model']
    self.ctx_encoder = BertEmbedding(model=model)
    self.can_encoder = BertEmbedding(model=model)
    decoder_layer = nn.TransformerDecoderLayer(d_model=768, nhead=args['nhead'])
    self.fusion_encoder = nn.TransformerDecoder(decoder_layer, num_layers=args['num_layers'])
    self.topk = 1 + args['gray_cand_num']
    self.args = args

def _encode(self, cid, rid, cid_mask, rid_mask, is_test=False, before_comp=False):
    rid_size, cid_size = len(rid), len(cid)
    # cid_rep_whole: [B_c, S, E]
    cid_rep_whole = self.ctx_encoder(cid, cid_mask, hidden=True)
    # cid_rep: [B_c, E]
    cid_rep = cid_rep_whole[:, 0, :]
    # cid_rep_: [B_c, 1, E]
    cid_rep_ = cid_rep_whole[:, 0, :].unsqueeze(1)
    # rid_rep: [B_r*K, E]
    if is_test:
        rid_rep = self.can_encoder(rid, rid_mask)
    else:
        rid_rep = self.can_encoder(rid, rid_mask)
        # rid_rep_whole: [B_r, K, E]
        rid_rep_whole = torch.stack(torch.split(rid_rep, self.topk))
        # rid_rep: [B_r, E]
        rid_rep = rid_rep_whole[:, 0, :]

    ### easy comparison
    # rep_cid_backup: [B_r, B_c, E]
    rep_rid = rid_rep.unsqueeze(1).expand(-1, cid_size, -1)
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)
    # rep: [B_r, B_c, E]
    rep = rep_cid + rep_rid
    # cid_rep_whole: [S, B_c, E]
    cid_rep_whole = cid_rep_whole.permute(1, 0, 2)
    # rest: [B_r, B_c, E]
    rest = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [B_c, E, B_r]
    rest = rest.permute(1, 2, 0)
    # dp: [B_c, B_r]
    dp = torch.bmm(cid_rep_, rest).squeeze(1)
    if is_test:
        return dp, cid_rep, rid_rep

    ### hard negative comparison
    # rep_rid: [K, B_r, E] (rid_rep_whole permuted from [B_r, K, E])
    rep_rid = rid_rep_whole.permute(1, 0, 2)
    # rep_cid: [K, B_c, E]
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)
    # rep: [K, B_r, E] (B_r == B_c within a batch)
    rep = rep_cid + rep_rid
    # rest: [K, B_r, E]
    rest = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [K, B_r, E] -> [B_r, E, K]
    rest = rest.permute(1, 2, 0)
    # dp: [B_c, K]
    dp2 = torch.bmm(cid_rep_, rest).squeeze(1)

    ### easy and hard comparison
    en_part = rid_rep_whole[:, 0, :]    # [B_r, E]
    size = len(en_part)
    en_part = en_part.unsqueeze(0).expand(size, -1, -1)    # [B_r, B_r, E]
    rest = []
    for i in range(size):
        index = list(range(size))
        index.remove(i)
        rest.append(en_part[i, index, :])
    rest = torch.stack(rest)    # [B_r, B_r-1, E]
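    # for each response i, gather the gold candidates of the other in-batch
    # responses to serve as additional easy negatives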
    rep_rid = torch.cat([rid_rep_whole, rest], dim=1).permute(1, 0, 2)    # [B_r+K-1, B_r, E]
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)     # [B_r+K-1, B_c, E]
    # rep: [B_r+K-1, B_c, E]
    rep = rep_cid + rep_rid
    # cid_rep_whole is already [S, B_c, E] from the permute above;
    # permuting it again would break the decoder memory layout
    # rest: [B_r+K-1, B_c, E]
    rest = self.fusion_encoder(
        rep,
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [B_c, E, B_r+K-1]
    rest = rest.permute(1, 2, 0)
    # dp3: [B_c, B_r+K-1]
    dp3 = torch.bmm(cid_rep_, rest).squeeze(1)

    if before_comp:
        return dp, dp2, dp3, cid_rep, rid_rep
    else:
        return dp, dp2, dp3

@torch.no_grad()
def predict(self, batch):
    cid = batch['ids']
    cid_mask = torch.ones_like(cid)
    rid = batch['rids']
    rid_mask = batch['rids_mask']
    dp, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask, is_test=True)    # [1, 10]
    # dp = torch.matmul(cid_rep, rid_rep.t())
    return dp.squeeze()

def forward(self, batch):
    cid = batch['ids']
    # rid: [B_r*K, S]
    rid = batch['rids']
    cid_mask = batch['ids_mask']
    rid_mask = batch['rids_mask']
    batch_size = len(cid)

    # [B_c, B_r]
    loss = 0
    if self.args['before_comp']:
        dp, dp2, dp3, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask, before_comp=True)
        # before comparison, optimize the absolute semantic space
        dot_product = torch.matmul(cid_rep, rid_rep.t())
        mask = torch.zeros_like(dot_product)
        mask[range(batch_size), range(batch_size)] = 1.
        loss_ = F.log_softmax(dot_product, dim=-1) * mask
        loss += (-loss_.sum(dim=1)).mean()
    else:
        dp, dp2, dp3 = self._encode(cid, rid, cid_mask, rid_mask)
    mask = torch.zeros_like(dp)
    mask[range(batch_size), range(batch_size)] = 1.
    loss_ = F.log_softmax(dp, dim=-1) * mask
    loss += (-loss_.sum(dim=1)).mean()

    mask = torch.zeros_like(dp2)
    mask[:, 0] = 1.
    loss_ = F.log_softmax(dp2, dim=-1) * mask
    loss += (-loss_.sum(dim=1)).mean()
    
    mask = torch.zeros_like(dp3)
    mask[range(batch_size), 0] = 1.
    loss_ = F.log_softmax(dp3, dim=-1) * mask
    loss += (-loss_.sum(dim=1)).mean()

    acc = (dp.max(dim=-1)[1] == torch.arange(batch_size).cuda()).to(torch.float).mean().item()
    return loss, acc

class BERTDualSCMHNL2REncoder(nn.Module):

def __init__(self, **args):
    super(BERTDualSCMHNL2REncoder, self).__init__()
    model = args['pretrained_model']
    self.ctx_encoder = BertEmbedding(model=model)
    self.can_encoder = BertEmbedding(model=model)
    decoder_layer = nn.TransformerDecoderLayer(d_model=768, nhead=args['nhead'])
    self.fusion_encoder = nn.TransformerDecoder(decoder_layer, num_layers=args['num_layers'])
    self.linear = nn.Sequential(
        nn.Dropout(p=args['dropout']),
        nn.Linear(768, 1)
    )
    # self.position_embd = nn.Embedding(512, 768)
    self.criterion = nn.CrossEntropyLoss()
    self.topk = 1 + args['gray_cand_num']
    self.args = args

def _encode(self, cid, rid, cid_mask, rid_mask, is_test=False, before_comp=False):
    rid_size, cid_size = len(rid), len(cid)
    # cid_rep_whole: [B_c, S, E]
    cid_rep_whole = self.ctx_encoder(cid, cid_mask, hidden=True)
    # cid_rep: [B_c, E]
    cid_rep = cid_rep_whole[:, 0, :]
    # cid_rep_: [B_c, 1, E]
    cid_rep_ = cid_rep_whole[:, 0, :].unsqueeze(1)
    # rid_rep: [B_r*K, E]
    if is_test:
        rid_rep = self.can_encoder(rid, rid_mask)
    else:
        rid_rep = self.can_encoder(rid, rid_mask)
        # rid_rep_whole: [B_r, K, E]
        rid_rep_whole = torch.stack(torch.split(rid_rep, self.topk))
        # rid_rep: [B_r, E]
        rid_rep = rid_rep_whole[:, 0, :]

    ## combine context and response embeddings before comparison
    # rep_cid_backup: [B_r, B_c, E]
    rep_rid = rid_rep.unsqueeze(1).expand(-1, cid_size, -1)
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)
    # pos_index = torch.arange(cid_size).cuda().unsqueeze(dim=-1).expand(-1, cid_size)    # [B_r, B_c]
    # rep_pos = self.position_embd(pos_index)    # [B_r, B_c, E]
    # rep: [B_r, B_c, E]
    rep = rep_cid + rep_rid
    # cid_rep_whole: [S, B_c, E]
    cid_rep_whole = cid_rep_whole.permute(1, 0, 2)
    # rest: [B_r, B_c, E]
    rest = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [B_c, B_r, E]
    rest = rest.permute(1, 0, 2)
    rest = self.linear(rest).squeeze(-1)    # [B_c, B_r]
    if is_test:
        return rest, cid_rep, rid_rep

    ### hard negative comparison
    # rep_rid: [K, B_r, E] (rid_rep_whole permuted from [B_r, K, E])
    rep_rid = rid_rep_whole.permute(1, 0, 2)
    # rep_cid: [K, B_c, E]
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)
    # pos_index = torch.arange(self.topk).cuda().unsqueeze(dim=-1).expand(-1, cid_size)    # [K, B_c]
    # rep_pos = self.position_embd(pos_index)
    # rep: [K, B_r, E]
    rep = rep_cid + rep_rid
    # rest: [K, B_r, E]
    rest2 = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [K, B_r, E] -> [B_r, K, E]
    rest2 = rest2.permute(1, 0, 2)
    rest2 = self.linear(rest2).squeeze(dim=-1)    # [B_c, K]

    ### hard negative and few easy negative
    rep_rid = rid_rep_whole.reshape(-1, 768).unsqueeze(1).expand(-1, cid_size, -1)
    rep_cid = cid_rep.unsqueeze(0).expand(len(rep_rid), -1, -1)
    # pos_index = torch.arange(len(rep_rid)).cuda().unsqueeze(dim=-1).expand(-1, cid_size)    # [B_r*K, B_c]
    # rep_pos = self.position_embd(pos_index)
    # rep: [B_r*K, B_c, E]
    rep = rep_cid + rep_rid 
    # cid_rep_whole is already [S, B_c, E] from the permute above;
    # permuting it again would break the decoder memory layout
    # rest: [B_r*K, B_c, E]
    rest3 = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )
    # rest: [B_c, B_r*K, E]
    rest3 = rest3.permute(1, 0, 2)
    rest3 = self.linear(rest3).squeeze(-1)    # [B_c, B_r*K]

    if before_comp:
        return rest, rest2, rest3, cid_rep, rid_rep
    else:
        return rest, rest2, rest3

@torch.no_grad()
def predict(self, batch):
    cid = batch['ids']
    cid_mask = torch.ones_like(cid)
    rid = batch['rids']
    rid_mask = batch['rids_mask']
    rest, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask, is_test=True)    # [1, 10]
    rest = F.softmax(rest.squeeze(dim=0), dim=-1)
    return rest

def forward(self, batch):
    cid = batch['ids']
    # rid: [B_r*K, S]
    rid = batch['rids']
    cid_mask = batch['ids_mask']
    rid_mask = batch['rids_mask']
    batch_size = len(cid)

    # [B_c, B_r]
    loss = 0
    if self.args['before_comp']:
        rest, rest2, rest3, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask, before_comp=True)
        # before comparison, optimize the absolute semantic space
        dot_product = torch.matmul(cid_rep, rid_rep.t())
        mask = torch.zeros_like(dot_product)
        mask[range(batch_size), range(batch_size)] = 1.
        loss_ = F.log_softmax(dot_product, dim=-1) * mask
        loss += (-loss_.sum(dim=1)).mean()
    else:
        rest, rest2, rest3 = self._encode(cid, rid, cid_mask, rid_mask)
    # rest: [B_c, B_r]; rest2: [B_c, K]
    rest_label = torch.arange(batch_size).cuda()
    rest_label_2 = torch.zeros(batch_size).cuda().to(torch.long)
    rest_label_3 = torch.arange(0, rest3.size(-1), self.topk).cuda().to(torch.long)
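    # rest3 lays the candidates out group by group, so the gold response for
    # context i sits at column i * self.topk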

    rest /= self.args['temp']
    rest2 /= self.args['temp']
    rest3 /= self.args['temp']

    loss += self.criterion(rest, rest_label)    # += keeps the before_comp loss accumulated above
    loss += self.criterion(rest2, rest_label_2)
    loss += self.criterion(rest3, rest_label_3)
    acc = (rest.max(dim=-1)[1] == torch.arange(batch_size).cuda()).to(torch.float).mean().item()
    return loss, acc

class BERTDualSCMPairwiseEncoder(nn.Module):

def __init__(self, **args):
    super(BERTDualSCMPairwiseEncoder, self).__init__()
    model = args['pretrained_model']
    self.ctx_encoder = BertEmbedding(model=model)
    self.can_encoder = BertEmbedding(model=model)
    decoder_layer = nn.TransformerDecoderLayer(d_model=768, nhead=args['nhead'])
    self.fusion_encoder = nn.TransformerDecoder(decoder_layer, num_layers=args['num_layers'])
    self.args = args

def _encode(self, cid, rid, cid_mask, rid_mask):
    rid_size, cid_size = len(rid), len(cid)
    # cid_rep_whole: [B_c, S, E]
    cid_rep_whole = self.ctx_encoder(cid, cid_mask, hidden=True)
    # cid_rep: [B_c, E]
    cid_rep = cid_rep_whole[:, 0, :]
    # cid_rep_: [B_c, 1, E]
    cid_rep_ = cid_rep_whole[:, 0, :].unsqueeze(1)
    # rid_rep: [B_r, E]
    # rid_rep = self.can_encoder(rid, rid_mask)
    rid_rep = torch.zeros(rid_size, 768).cuda()

    # cid_rep_mt, rid_rep_mt = self.convert_ctx(cid_rep), self.convert_res(rid_rep)
    cid_rep_mt, rid_rep_mt = cid_rep.clone(), rid_rep.clone()

    ## combine context and response embeddings before comparison
    # cid_rep: [B_r, B_c, E]
    cid_rep = cid_rep.unsqueeze(0).expand(rid_size, -1, -1)
    # rid_rep: [B_r, B_c, E]
    rid_rep = rid_rep.unsqueeze(1).expand(-1, cid_size, -1)
    rep = rid_rep + cid_rep
    # rep: [B_r, B_c, E] (element-wise sum, not concatenation)

    # cid_rep_whole: [S, B_c, E]
    cid_rep_whole = cid_rep_whole.permute(1, 0, 2)
    # rest: [B_r, B_c, E]
    rest = self.fusion_encoder(
        rep, 
        cid_rep_whole,
        memory_key_padding_mask=~cid_mask.to(torch.bool),
    )

    # rest: [B_c, E, B_r]
    rest = rest.permute(1, 2, 0)
    # dp: [B_c, B_r]
    dp_dp = torch.bmm(cid_rep_, rest).squeeze(1)
    return dp_dp, cid_rep_mt, rid_rep_mt

@torch.no_grad()
def get_cand(self, ids, ids_mask):
    self.eval()
    rest = self.can_encoder(ids, ids_mask)
    rest = self.convert_res(rest)
    return rest

@torch.no_grad()
def get_ctx(self, ids, ids_mask):
    self.eval()
    rest = self.ctx_encoder(ids, ids_mask)
    rest = self.convert_ctx(rest)
    return rest

@torch.no_grad()
def predict(self, batch):
    self.eval()
    cid = batch['ids']
    cid_mask = torch.ones_like(cid)
    rid = batch['rids']
    rid_mask = batch['rids_mask']
    dp, cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask)    # [1, 10]
    return dp.squeeze()

def forward(self, batch):
    cid = batch['ids']
    rid = batch['rids']
    cid_mask = batch['ids_mask']
    rid_mask = batch['rids_mask']
    batch_size = len(cid)

    dp, cid_rep_mt, rid_rep_mt = self._encode(cid, rid, cid_mask, rid_mask)
    loss, loss_margin = 0, 0
    # multi-task: recall training
    if self.args['coarse_recall_loss']:
        dp_mt = torch.matmul(cid_rep_mt, rid_rep_mt.t())
        mask = torch.zeros_like(dp_mt)
        mask[range(batch_size), range(batch_size)] = 1.
        loss_ = F.log_softmax(dp_mt, dim=-1) * mask
        loss += (-loss_.sum(dim=1)).mean()

    mask = torch.zeros_like(dp)
    mask[range(batch_size), range(batch_size)] = 1.
    loss_ = F.log_softmax(dp, dim=-1) * mask
    loss += (-loss_.sum(dim=1)).mean()
    
    acc = (dp.max(dim=-1)[1] == torch.arange(batch_size).cuda()).to(torch.float).mean().item()
    return loss, acc
```
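To answer the original question directly: the interaction layer is the `nn.TransformerDecoder` (`self.fusion_encoder`) together with the fuse-and-score logic around it. Below is a minimal, self-contained sketch of that pattern in isolation, without the training losses. The `InteractionLayer` name, the default sizes, and the dummy tensors are illustrative assumptions, not code from this repo; in the models above, `nhead` and `num_layers` come from `args`.

```python
import torch
import torch.nn as nn

class InteractionLayer(nn.Module):
    """Sketch of the comparison module: a TransformerDecoder whose target
    sequence is the fused context-response grid and whose memory is the
    sequence of context token states."""

    def __init__(self, d_model=768, nhead=12, num_layers=2):
        super().__init__()
        layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.fusion_encoder = nn.TransformerDecoder(layer, num_layers=num_layers)

    def forward(self, cid_rep_whole, cid_mask, rid_rep):
        # cid_rep_whole: [B_c, S, E] context token states (e.g. BERT hidden states)
        # cid_mask:      [B_c, S]    1 for real tokens, 0 for padding
        # rid_rep:       [B_r, E]    candidate response embeddings
        B_c = cid_rep_whole.size(0)
        B_r = rid_rep.size(0)
        cid_rep = cid_rep_whole[:, 0, :]    # [B_c, E], the [CLS] vectors
        # fuse every (response, context) pair by addition: [B_r, B_c, E]
        rep = (rid_rep.unsqueeze(1).expand(-1, B_c, -1)
               + cid_rep.unsqueeze(0).expand(B_r, -1, -1))
        # decoder memory: [S, B_c, E]
        memory = cid_rep_whole.permute(1, 0, 2)
        rest = self.fusion_encoder(
            rep,
            memory,
            memory_key_padding_mask=~cid_mask.to(torch.bool),
        )    # [B_r, B_c, E]
        # score each candidate against the context [CLS] vector: [B_c, B_r]
        return torch.bmm(cid_rep.unsqueeze(1), rest.permute(1, 2, 0)).squeeze(1)

# usage with dummy tensors (sizes are illustrative)
layer = InteractionLayer()
ctx = torch.randn(4, 32, 768)     # 4 contexts, 32 tokens each
mask = torch.ones(4, 32)
cands = torch.randn(4, 768)       # 4 candidate responses
scores = layer(ctx, mask, cands)  # [4, 4]
```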

Thanks!

I cannot find this code in the repo. Are your results for DR-BERT (as mentioned in the paper) obtained with this model or with some other one?