mims-harvard/Raincoat

Which part of the loss function does 'lossinner' correspond to?

zhrli opened this issue · 1 comments

zhrli commented

class RAINCOAT(Algorithm):
    """RAINCOAT domain-adaptation model: a time/frequency feature extractor
    paired with a reconstruction decoder and a classifier head, trained with
    a Sinkhorn (optimal-transport) alignment loss between source and target
    features plus a supervised contrastive loss on source features.

    Args:
        configs: model configuration object passed to the encoder/decoder/
            classifier constructors and to the ``Algorithm`` base class.
        hparams: dict providing at least ``"learning_rate"`` and
            ``"weight_decay"``.
        device: torch device the submodules are moved to.
    """

    # NOTE(review): the issue paste rendered "__init__" as "init" because
    # markdown stripped the double underscores; restored here so the class
    # is actually constructible.
    def __init__(self, configs, hparams, device):
        super(RAINCOAT, self).__init__(configs)
        self.feature_extractor = tf_encoder(configs).to(device)
        self.decoder = tf_decoder(configs).to(device)
        self.classifier = classifier(configs).to(device)

        # Main optimizer trains encoder + classifier. The decoder is
        # deliberately excluded here — it is trained through `coptimizer`
        # below (the original paste carried this as commented-out code).
        self.optimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters())
            + list(self.classifier.parameters()),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"],
        )
        # Correction optimizer trains encoder + decoder at half the
        # main learning rate.
        self.coptimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters())
            + list(self.decoder.parameters()),
            lr=0.5 * hparams["learning_rate"],
            weight_decay=hparams["weight_decay"],
        )

        self.hparams = hparams
        # Sum-reduced L1 loss for input reconstruction.
        self.recons = nn.L1Loss(reduction='sum').to(device)
        # pi derived as acos(0) * 2.
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        # Supervised contrastive loss (pytorch-metric-learning) applied to
        # source features with their labels.
        self.loss_func = losses.ContrastiveLoss(pos_margin=0.5)
        # Entropy-regularized OT distance used to align source/target features.
        self.sink = SinkhornDistance(eps=1e-3, max_iter=1000, reduction='sum')
    
def update(self, src_x, src_y, trg_x):
    """One training step on a source batch (with labels) and a target batch.

    Accumulates gradients from four separate losses — reconstruction,
    Sinkhorn alignment, contrastive, and source classification — via
    sequential ``backward(retain_graph=True)`` calls, then applies a single
    ``self.optimizer`` step. Returns a dict of scalar loss values.

    NOTE(review): per the maintainer's reply in this thread, `lossinner`
    (the contrastive loss) was only used for testing and is not part of the
    paper's loss; it still contributes gradients here.
    NOTE(review): `self.coptimizer` (encoder+decoder) is never stepped in
    this method, so the reconstruction gradients only affect parameters
    covered by `self.optimizer` — confirm against the full class.
    NOTE(review): `self.cross_entropy` is not defined in the visible
    __init__ — presumably inherited from the `Algorithm` base class.
    """

    self.optimizer.zero_grad()
    # Encoder returns a feature tensor plus an auxiliary output consumed
    # by the decoder (shapes not visible from here).
    src_feat, out_s = self.feature_extractor(src_x)
    trg_feat, out_t = self.feature_extractor(trg_x)
    src_recon = self.decoder(src_feat, out_s)
    trg_recon = self.decoder(trg_feat, out_t)
    # Reconstruction loss, scaled down by 1e-4; sum-reduced L1 (see __init__).
    recons = 1e-4*(self.recons(src_recon, src_x)+self.recons(trg_recon, trg_x))
    # retain_graph=True keeps the shared encoder graph alive for the
    # subsequent backward calls below.
    recons.backward(retain_graph=True)
    # Sinkhorn OT distance between source and target feature batches.
    dr, _, _ = self.sink(src_feat, trg_feat)
    sink_loss = 1 *dr
    sink_loss.backward(retain_graph=True)
    # Supervised contrastive loss on source features (testing-only, per thread).
    lossinner = 1 * self.loss_func(src_feat, src_y)
    lossinner.backward(retain_graph=True)
    src_pred = self.classifier(src_feat)
    # Source-domain classification loss.
    loss_cls = 1 *self.cross_entropy(src_pred, src_y) 
    # NOTE(review): retain_graph is not needed on this final backward;
    # kept as-is to avoid altering runtime behavior/memory profile.
    loss_cls.backward(retain_graph=True)
    # Single step over encoder + classifier with all accumulated gradients.
    self.optimizer.step()
    return {'Src_cls_loss': loss_cls.item(),'Sink': sink_loss.item(), 'inner': lossinner.item()}

Hi, `lossinner` is not part of the final loss; it was only used for testing. We have updated the code. Thank you!