KD loss
PaTricksStar opened this issue · 21 comments
Why softmax for the teacher output, but log-softmax for the student output?
It would help me understand your question better if you could mention which file and lines you are referring to.
Oh, it is in the loss_fn_kd function in model/net.py, line 107.
I see. You can refer to the definition/documentation of PyTorch's KL divergence loss (KLDivLoss). It expects the input to be log-probabilities and the target to be probabilities. That's why we apply log-softmax to the student outputs and softmax to the teacher outputs (both of which are raw scores).
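A minimal sketch of that argument convention, using random tensors as stand-ins for real teacher/student outputs and assuming `reduction='batchmean'` (which divides by the batch size only). `nn.KLDivLoss` takes log-probabilities as `input` and probabilities as `target`, and computes `target * (log(target) - input)` element-wise:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, T = 4, 10, 2.0                      # batch size, classes, temperature (illustrative values)
teacher_logits = torch.randn(N, C)        # stand-in for raw teacher scores
student_logits = torch.randn(N, C)        # stand-in for raw student scores

log_q = F.log_softmax(student_logits / T, dim=1)  # input: student log-probabilities
p = F.softmax(teacher_logits / T, dim=1)          # target: teacher probabilities

kl = nn.KLDivLoss(reduction='batchmean')(log_q, p)

# the same quantity written out by hand: sum_i p_i * (log p_i - log q_i), averaged over the batch
manual = (p * (p.log() - log_q)).sum(dim=1).mean()
print(kl.item(), manual.item())
```

Swapping the two arguments, or passing raw scores instead of (log-)probabilities, would still run but compute a different quantity.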
Thanks for your reply.
@peterliht Why is KL divergence used to compute the KD loss? The paper "Distilling the Knowledge in a Neural Network" says:
"The first objective function is the cross entropy with the soft targets"
Shouldn't the KD loss then be the following?
$$-\sum_{i=1}^{C} \mathrm{soft\_target}_i \, \log\big(\mathrm{softmax}(\mathrm{student\_cls\_output} / T)_i\big)$$
@xmfbit Indeed, I initially tried to implement cross entropy with the soft targets directly. However, in PyTorch the built-in CrossEntropy loss function only takes “(output, target)”, where the target (i.e., the label) is a class index rather than a one-hot or probability vector, which is what the KD loss needs. That's why I turned to KL divergence: the two lead to the same optimization result, and KL divergence works naturally with our data representation.
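The identity behind "the same optimization result", written out as a standard derivation with $p$ the teacher's softened distribution and $q = \mathrm{softmax}(z / T)$ the student's, for student logits $z$: the KL divergence and the soft-target cross entropy differ only by the teacher entropy $H(p)$, which does not depend on $z$, so their gradients with respect to $z$ coincide.

$$
D_{\mathrm{KL}}(p \,\|\, q) = \sum_{i=1}^{C} p_i \log\frac{p_i}{q_i}
= \underbrace{-\sum_{i=1}^{C} p_i \log q_i}_{H(p,\,q)} \;-\; \underbrace{\Big(-\sum_{i=1}^{C} p_i \log p_i\Big)}_{H(p)}
$$

$$
\frac{\partial \log q_j}{\partial z_i} = \frac{\delta_{ij} - q_i}{T}
\quad\Longrightarrow\quad
\frac{\partial D_{\mathrm{KL}}(p \,\|\, q)}{\partial z_i} = \frac{\partial H(p, q)}{\partial z_i}
= -\sum_{j=1}^{C} p_j \, \frac{\delta_{ij} - q_i}{T} = \frac{q_i - p_i}{T}
$$

The last step uses $\sum_j p_j = 1$. With $T = 1$ and a mean over a batch of $N$ samples, this becomes the $(q - p)/N$ gradient that is checked numerically later in the thread.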
You're welcome to try to define a customized CrossEntropy loss function that also leverages PyTorch’s optimized C-backend (you could also define one from scratch, but that might be very slow). If successful, please let us know. Thanks!
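For readers on newer PyTorch: since version 1.10, `F.cross_entropy` (and `nn.CrossEntropyLoss`) also accept class probabilities as the target, so the soft-target cross entropy can use the built-in backend directly. A minimal sketch, assuming PyTorch 1.10 or later; the helper name `soft_target_ce` is hypothetical:

```python
import torch
import torch.nn.functional as F

def soft_target_ce(student_logits, teacher_logits, T):
    """Cross entropy with soft targets, relying on F.cross_entropy's
    class-probability targets (PyTorch >= 1.10)."""
    soft_target = F.softmax(teacher_logits.detach() / T, dim=1)
    # default reduction='mean' averages over the batch for probability targets
    return F.cross_entropy(student_logits / T, soft_target)

torch.manual_seed(0)
student_logits = torch.randn(8, 5)
teacher_logits = torch.randn(8, 5)
T = 4.0

builtin = soft_target_ce(student_logits, teacher_logits, T)
# hand-written version: -sum_i p_i log q_i, averaged over the batch
p = F.softmax(teacher_logits / T, dim=1)
by_hand = -(p * F.log_softmax(student_logits / T, dim=1)).sum(dim=1).mean()
print(builtin.item(), by_hand.item())
```

As in the thread, you would still multiply this term by `T * T` before combining it with the hard-label loss.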
OK, I understand. Here is my naive implementation. I only tested it on MNIST, so speed was not very important:
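import torch
import torch.nn as nn
import torch.nn.functional as F
# inside the training loop; teacher, student, device, T and alpha are defined elsewhere
x, target = x.to(device), target.to(device)
with torch.no_grad():
    teacher_out = teacher(x)
    soft_target = F.softmax(teacher_out / T, dim=1)
hard_target = target
out = student(x)  # student logits, i.e. the input to softmax
logp = F.log_softmax(out / T, dim=1)
# cross entropy with the soft targets, averaged over the batch
loss_soft_target = -torch.mean(torch.sum(soft_target * logp, dim=1))
loss_hard_target = nn.CrossEntropyLoss()(out, hard_target)
loss = loss_soft_target * T * T + alpha * loss_hard_target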
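To run the training step above end to end, here is a minimal sketch that fills in the pieces it leaves undefined. Everything beyond the loss lines is an assumption for illustration: hypothetical toy `nn.Linear` teacher/student models, a random MNIST-sized batch, and illustrative values for `T`, `alpha` and the learning rate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
device = torch.device('cpu')
T, alpha = 4.0, 0.5                                   # illustrative hyperparameters
teacher = nn.Linear(784, 10).to(device).eval()        # hypothetical toy teacher
student = nn.Linear(784, 10).to(device)               # hypothetical toy student
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

x = torch.randn(32, 784)                              # fake MNIST-sized batch (flattened)
target = torch.randint(0, 10, (32,))

x, target = x.to(device), target.to(device)
with torch.no_grad():                                 # no graph is built through the teacher
    soft_target = F.softmax(teacher(x) / T, dim=1)
out = student(x)
logp = F.log_softmax(out / T, dim=1)
loss_soft_target = -torch.mean(torch.sum(soft_target * logp, dim=1))
loss_hard_target = nn.CrossEntropyLoss()(out, target)
loss = loss_soft_target * T * T + alpha * loss_hard_target

optimizer.zero_grad()
loss.backward()                                       # gradients reach the student only
optimizer.step()
print(loss.item())
```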
@xmfbit H(p) is constant, right?
Yes, it does not take part in the backward pass, so we can ignore this term.
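A quick numerical check of that point, as a sketch with random logits (the sizes and `T` are arbitrary): the soft-target cross entropy minus the KL divergence equals the teacher entropy H(p), and both losses give the same gradient with respect to the student logits, namely (q - p) / (N * T).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, T = 6, 5, 3.0
teacher_logits = torch.randn(N, C)
s = torch.randn(N, C, requires_grad=True)         # student logits

p = F.softmax(teacher_logits / T, dim=1)
log_q = F.log_softmax(s / T, dim=1)

ce = -(p * log_q).sum(dim=1).mean()               # H(p, q), averaged over the batch
kl = F.kl_div(log_q, p, reduction='batchmean')    # D_KL(p || q), averaged over the batch
entropy = -(p * p.log()).sum(dim=1).mean()        # H(p): depends only on the teacher

grad_ce, = torch.autograd.grad(ce, s, retain_graph=True)
grad_kl, = torch.autograd.grad(kl, s)
print((ce - kl - entropy).abs().item())           # should be close to 0 (floating-point error only)
print((grad_ce - grad_kl).abs().max().item())     # should be close to 0 (floating-point error only)
```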
Has anyone tried cross-entropy? Does it work better or worse than KL?
No. They should lead to the same or similar results, given the discussion above.
@michaelklachko The gradients with respect to the student's output are the same using KL divergence and the classic KD loss from Hinton's paper. You can refer to the figure given by nowgood. Use this code to check it numerically (p and q here differ from those in nowgood's figure).
import torch
import torch.nn as nn
import torch.nn.functional as F
N = 10
C = 5
# softmax output by teacher
p = torch.softmax(torch.rand(N, C), dim=1)
# softmax output by student
q = torch.softmax(torch.rand(N, C), dim=1)
#q = torch.ones(N, C)
q.requires_grad = True
# KL divergence
kl_loss = nn.KLDivLoss()(torch.log(q), p)
kl_loss.backward()
grad = q.grad
q.grad.zero_()
ce_loss = torch.mean(torch.log(q) * p)
ce_loss.backward()
grad_check = q.grad
print(grad)
print(grad_check)
Great, thanks!
@michaelklachko Sorry, but I found some bugs in the code I provided: when checking the gradient with print, grad refers to the same tensor instance as grad_check, so they are necessarily equal! I paste a corrected version below.
I also want to point out that the author's implementation of the KD loss using torch.nn.KLDivLoss scales the gradient down by the number of classification categories compared with using CrossEntropy, because the default reduction averages over every element (batch size times number of classes) instead of over the batch only. See the torch documentation for details:
size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample.
reduction (string, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘elementwise_mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘elementwise_mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: ‘elementwise_mean’
import torch
import torch.nn as nn
import torch.nn.functional as F
# sample number
N = 10
# category number
C = 5
# softmax output of teacher
p = torch.softmax(torch.rand(N, C), dim=1)
# logit output of student
s = torch.rand(N, C, requires_grad=True)
# softmax output of student, T = 1
q = torch.softmax(s, dim=1)
# KL divergence
# this is the author's implementation
# PyTorch averages over all N * C elements because that is the default option
# kl_loss = nn.KLDivLoss()(torch.log(q), p)
# I think this should be the right solution
kl_loss = (nn.KLDivLoss(reduction='none')(torch.log(q), p)).sum(dim=1).mean()
kl_loss.backward(retain_graph=True)
print('grad using KLDivLoss')
print(s.grad)
# clear the grad
s.grad.zero_()
# bug 2: should not take an element-wise mean
# ce_loss = torch.mean(-torch.log(q) * p)
ce_loss = torch.mean(torch.sum(-torch.log(q) * p, dim=1))
ce_loss.backward()
print('grad using ce loss')
print(s.grad)
# the real gradient of s should be `(q - p) / batch_size`
print('real grad, should be (q-p) / batch_size')
print((q - p) / N)
@peterliht Could you check this?
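In current PyTorch the `'elementwise_mean'` option quoted above is called `'mean'`, and `KLDivLoss` also offers `reduction='batchmean'`, which divides by the batch size only and matches the mathematical definition of KL divergence. A sketch, reusing the sizes from the code above, that shows the factor-of-C difference in both the loss and its gradient (PyTorch may emit a warning about `'mean'` here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 10, 5
p = torch.softmax(torch.rand(N, C), dim=1)    # teacher probabilities
s = torch.rand(N, C, requires_grad=True)      # student logits (T = 1)
log_q = F.log_softmax(s, dim=1)

# 'mean' divides by N * C; 'batchmean' divides by N only
kl_mean = nn.KLDivLoss(reduction='mean')(log_q, p)
kl_batchmean = nn.KLDivLoss(reduction='batchmean')(log_q, p)

grad_mean, = torch.autograd.grad(kl_mean, s, retain_graph=True)
grad_batchmean, = torch.autograd.grad(kl_batchmean, s)
q = log_q.exp().detach()
print(torch.allclose(grad_batchmean, (q - p) / N, atol=1e-6))     # matches (q - p) / batch_size
print(torch.allclose(grad_mean * C, grad_batchmean, atol=1e-6))  # 'mean' is smaller by a factor of C
```

`reduction='batchmean'` gives the same value as the `reduction='none'` then `.sum(dim=1).mean()` line in the corrected code.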
When I tried to use the author's loss code on a new task with 1000 categories, I found that the KL loss term is very small, nearly 1e-6, and on both CIFAR-10 and my task the KL loss never seems to decrease. Can you tell me how to fix this problem? Thank you.
Try decreasing your learning rate, or train with the CE loss alone to check for bugs.
@Bo396543018 Same here. Did you solve the problem?
Hi,
Using the #2 (comment) implementation, I get large soft-loss values, around 200, but using the https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 implementation, my soft-loss value is around 1e-6. This is surprising, because we expect the results to be similar.
Any help would be appreciated. Thanks!
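One plausible account of the gap, sketched under two assumptions: that "#2 (comment)" is the hand-written soft-target cross entropy scaled by `T * T` posted earlier in this thread, and that the net.py version calls `nn.KLDivLoss` with its default mean reduction (the exact scaling in net.py is not reproduced here). Cross entropy still contains the teacher entropy H(p), which never goes to zero, while the default-mean KL is divided by the number of classes C as well as the batch size. The random logits and `T` below are illustrative only:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, T = 16, 1000, 4.0                       # 1000 classes as in the report above; T is illustrative
teacher_logits = torch.randn(N, C)
student_logits = torch.randn(N, C)

p = F.softmax(teacher_logits / T, dim=1)
log_q = F.log_softmax(student_logits / T, dim=1)

ce_version = -(p * log_q).sum(dim=1).mean() * T * T                # soft-target cross entropy
kl_batchmean = F.kl_div(log_q, p, reduction='batchmean') * T * T  # KL averaged over the batch
kl_default_mean = F.kl_div(log_q, p, reduction='mean') * T * T    # KL averaged over all N * C elements
entropy = -(p * p.log()).sum(dim=1).mean() * T * T                 # teacher entropy term

# ce_version = kl_batchmean + entropy, and kl_default_mean = kl_batchmean / C
print(ce_version.item(), kl_batchmean.item(), kl_default_mean.item())
```

Neither difference changes which student minimizes the loss. However, the values are not directly comparable, and the factor of C also shrinks the gradient of the default-mean version.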
So what is the answer to the question: why softmax for the teacher output, but log-softmax for the student output?
@pratikchhapolika
https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
It's just how PyTorch's KLDivLoss() takes its arguments:
input = the student's predicted log-softmax, log q(y)
target = the teacher's softmax, p(y) (when log_target=False)
As shown above, KL divergence = cross entropy - entropy, i.e. D_KL(p || q) = H(p, q) - H(p), and only the cross-entropy term depends on the student.
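The linked `torch.nn.functional.kl_div` page also documents a `log_target` flag. With `log_target=True`, the target is given as log-probabilities too, so the teacher can go through log-softmax as well. A small sketch with random logits (values illustrative) showing that both forms give the same loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, T = 4, 6, 2.0
teacher_logits = torch.randn(N, C)
student_logits = torch.randn(N, C)

log_q = F.log_softmax(student_logits / T, dim=1)   # input: always log-probabilities
p = F.softmax(teacher_logits / T, dim=1)            # target as probabilities
log_p = F.log_softmax(teacher_logits / T, dim=1)    # target as log-probabilities

kl_prob_target = F.kl_div(log_q, p, reduction='batchmean')                     # log_target=False (default)
kl_log_target = F.kl_div(log_q, log_p, reduction='batchmean', log_target=True)
print(kl_prob_target.item(), kl_log_target.item())
```

Passing the target in log space avoids taking the log of probabilities that may have underflowed to zero.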