Which part of the loss function does `lossinner` correspond to?
zhrli opened this issue · 1 comments
zhrli commented
class RAINCOAT(Algorithm):
    """Training wrapper that couples an encoder, a decoder and a classifier.

    One ``update`` step back-propagates four losses computed from a source
    batch and a target batch: reconstruction, Sinkhorn distance between the
    source and target features, a contrastive loss on the source features,
    and the source classification loss.

    NOTE(review): ``self.cross_entropy`` is not defined in this class;
    presumably it is provided by ``Algorithm`` -- confirm against the base
    class.
    """

    def __init__(self, configs, hparams, device):
        """Build the networks, the two optimizers and the loss objects.

        Args:
            configs: Passed to the base class and to every network
                constructor.
            hparams: Mapping that must provide ``"learning_rate"`` and
                ``"weight_decay"``.
            device: Device the networks and the reconstruction loss are
                moved to.
        """
        super(RAINCOAT, self).__init__(configs)

        self.feature_extractor = tf_encoder(configs).to(device)
        self.decoder = tf_decoder(configs).to(device)
        self.classifier = classifier(configs).to(device)

        # Main optimizer: encoder + classifier only. The decoder parameters
        # are intentionally left out here; they are only covered by
        # ``coptimizer`` below.
        self.optimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters())
            + list(self.classifier.parameters()),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"],
        )
        # Second optimizer over encoder + decoder at half the learning rate.
        # It is not stepped by ``update``; presumably it is used by another
        # method outside this view -- TODO confirm.
        self.coptimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters())
            + list(self.decoder.parameters()),
            lr=0.5 * hparams["learning_rate"],
            weight_decay=hparams["weight_decay"],
        )

        self.hparams = hparams
        self.recons = nn.L1Loss(reduction='sum').to(device)
        # acos(0) * 2 == pi, evaluated through torch.
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.loss_func = losses.ContrastiveLoss(pos_margin=0.5)
        self.sink = SinkhornDistance(eps=1e-3, max_iter=1000, reduction='sum')

    def update(self, src_x, src_y, trg_x):
        """Run one optimization step of ``self.optimizer``.

        Args:
            src_x: Source-domain input batch.
            src_y: Labels for ``src_x``.
            trg_x: Target-domain input batch (no labels are used).

        Returns:
            dict with the scalar values of the classification loss
            (``'Src_cls_loss'``), the Sinkhorn loss (``'Sink'``) and the
            contrastive loss (``'inner'``).
        """
        self.optimizer.zero_grad()

        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)

        # Reconstruction loss on both domains, scaled down by 1e-4.
        # NOTE(review): this backward pass also accumulates gradients on the
        # decoder parameters, which ``self.optimizer`` neither zeroes nor
        # steps -- confirm they are cleared elsewhere.
        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)
        recons = 1e-4 * (
            self.recons(src_recon, src_x) + self.recons(trg_recon, trg_x)
        )
        # Every loss is back-propagated separately through the same encoder
        # graph, so the graph has to be retained between the passes.
        recons.backward(retain_graph=True)

        # Alignment term: only the first value returned by the Sinkhorn
        # module is used.
        sink_loss, _, _ = self.sink(src_feat, trg_feat)
        sink_loss.backward(retain_graph=True)

        # Contrastive loss on the labelled source features. It is kept so the
        # gradients and the 'inner' entry of the returned dict stay as before.
        lossinner = self.loss_func(src_feat, src_y)
        lossinner.backward(retain_graph=True)

        src_pred = self.classifier(src_feat)
        loss_cls = self.cross_entropy(src_pred, src_y)
        loss_cls.backward(retain_graph=True)

        self.optimizer.step()
        return {
            'Src_cls_loss': loss_cls.item(),
            'Sink': sink_loss.item(),
            'inner': lossinner.item(),
        }
hehuannb commented
Hi, `lossinner` is not used; it was only there for testing. We have updated the code. Thank you!