Loss
PeterLuoCoder opened this issue · 0 comments
PeterLuoCoder commented
"Hello, I'm studying your code and have a few questions. Sorry to bother you, but could you please explain the following? In train.py, do the use_mixup and use_edge options call other functions? I couldn't find any other files related to them. Also, is the loss computed as CrossEntropyLoss() + Edge_loss()? If I want to compute only the CrossEntropyLoss, how should I modify the code? Thank you."
class FullModel(nn.Module):
    """Wrap a model together with the loss(es) used to train it.

    In training mode ``forward`` returns the total loss; in evaluation mode it
    returns the model's prediction.

    ``use_mixup`` and ``use_edge`` are plain flags read from ``args2``; they
    only select which code path below runs. ``CrossEntropyLoss``,
    ``Edge_loss``, ``Mixup`` and ``edge_contour`` are defined elsewhere in the
    project.

    How the training loss is composed:

    * ``use_mixup`` true: the whole computation is delegated to ``Mixup``,
      which receives ``[ce_loss, edge_loss]`` when ``use_edge`` is true and
      just ``ce_loss`` otherwise.
    * ``use_edge`` true and the model returns a list/tuple: ``ce_loss`` is
      applied to every element except the last, ``edge_loss`` is applied to
      the last element against ``edge_contour(label).long()``, and all terms
      are summed (i.e. CE + edge loss).
    * ``use_edge`` false and the model returns a list/tuple: ``ce_loss`` is
      applied to every element and the results are summed.
    * The model returns a single output: only ``ce_loss`` is used, regardless
      of ``use_edge``.

    To train with cross-entropy only, set ``args2.use_edge = False``.
    NOTE(review): if the model still emits an edge map as its last output in
    that configuration, that map is then also passed to ``ce_loss`` — confirm
    the model is built without the edge head when disabling ``use_edge``.

    Args:
        model: the network to wrap; called as ``model(input)`` and also handed
            to ``Mixup`` when mixup is enabled.
        args2: config namespace; only ``use_mixup`` and ``use_edge`` are read.
    """

    def __init__(self, model, args2):
        super().__init__()
        self.model = model
        self.use_mixup = args2.use_mixup
        self.use_edge = args2.use_edge
        # Alternative once tried here instead of plain CE: Edge_weak_loss().
        self.ce_loss = CrossEntropyLoss()
        # Built unconditionally, but forward() only uses it when use_edge is set.
        self.edge_loss = Edge_loss()
        if self.use_mixup:
            # Only created when mixup is enabled; forward() only reads it then.
            self.mixup = Mixup(use_edge=args2.use_edge)

    def forward(self, input, label=None, train=True):
        """Return the training loss, or the prediction when ``train`` is False.

        Args:
            input: batch passed to ``self.model`` (or to ``self.mixup``).
            label: targets; required when ``train`` is True.
            train: select loss computation (True) or inference (False).

        Returns:
            When ``train`` is True, the summed loss described on the class.
            Otherwise the model output, or its first element if the model
            returns a list/tuple.

        Raises:
            ValueError: if ``train`` is True and ``label`` is None.
        """
        if train and label is None:
            # Fail fast with a clear message instead of an opaque error from
            # inside the loss function after a full forward pass.
            raise ValueError("label is required when train=True")

        if train and self.use_mixup:
            # Mixup is handed the model and the loss function(s) and returns
            # the loss directly.
            loss_fns = [self.ce_loss, self.edge_loss] if self.use_edge else self.ce_loss
            return self.mixup(input, label, loss_fns, self.model)

        output = self.model(input)
        is_multi = isinstance(output, (list, tuple))

        if not train:
            # For multi-output models the first element is the prediction.
            return output[0] if is_multi else output

        if not is_multi:
            return self.ce_loss(output, label)

        if self.use_edge:
            # All heads but the last get CE; the last head gets the edge loss.
            loss = sum(self.ce_loss(head, label) for head in output[:-1])
            return loss + self.edge_loss(output[-1], edge_contour(label).long())

        return sum(self.ce_loss(head, label) for head in output)