torch-BlazefaceDistillation
Abstract
This repository distills the performance of the MobileNet-backbone BlazeFace (Google) into a ResNet-backbone BlazeFace (original).
Distillation structure
The distillation loss combines a KL-divergence term on temperature-scaled softmax scores with a box-regression term.
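In outline (matching the code below), with s and t the clamped, sigmoid-activated student and teacher confidence scores, b_s and b_t the raw box regressors, T the temperature, and α the weighting factor:

    L_conf = α · T² · KL( log_softmax(s / T) ‖ softmax(t / T) ) + (1 − α) · BCE(s, t)
    L_box  = MSE( decode(b_s), decode(b_t) )
    L      = L_conf + L_box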
import torch
import torch.nn as nn
import torch.nn.functional as F

def kl_divergence_loss(logits, target):
    T = 0.01        # softmax temperature
    alpha = 0.6     # weight of the KL term vs. the BCE term
    thresh = 100    # clamp raw scores before sigmoid to avoid saturation
    criterion = nn.MSELoss()

    # c: classification (confidence) distillation term
    log2div = logits[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    tar2div = target[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1).detach()
    closs = nn.KLDivLoss(reduction="batchmean")(
        F.log_softmax(log2div / T, dim=1),
        F.softmax(tar2div / T, dim=1),
    ) * (alpha * T * T) + F.binary_cross_entropy(log2div, tar2div) * (1 - alpha)

    # r: box-regression term on boxes decoded with the repository's SSD anchors
    anchor = load_anchors("src/anchors.npy")
    rlogits = decode_boxes(logits[0], anchor)
    rtarget = decode_boxes(target[0], anchor)
    rloss = criterion(rlogits, rtarget)

    return closs + rloss
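A minimal training-step sketch of how this loss might be used (teacher, student, loader, and optimizer are hypothetical names; it assumes both models return a (raw_boxes, raw_scores) pair, as BlazeFace does):

    teacher.eval()    # frozen MobileNet-backbone BlazeFace (teacher)
    student.train()   # ResNet-backbone BlazeFace being distilled (student)

    for images in loader:
        with torch.no_grad():
            t_out = teacher(images)     # teacher predictions used as soft targets
        s_out = student(images)         # student predictions

        loss = kl_divergence_loss(s_out, t_out)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()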
Distillation performance (ResNet-backbone BlazeFace)
1. MobileNet-backbone (Google pretrained tflite model)
2. ResNet-backbone (distillation-customized BlazeFace)
Training log
1. Loss curve
2. Output MAE accuracy
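As an illustration of how the output MAE between teacher and student could be measured (an assumption about what this plot tracks; hypothetical names, same output convention as the training sketch above):

    with torch.no_grad():
        t_boxes, t_scores = teacher(images)
        s_boxes, s_scores = student(images)
        box_mae = (s_boxes - t_boxes).abs().mean()
        score_mae = (s_scores.sigmoid() - t_scores.sigmoid()).abs().mean()
    print(f"box MAE: {box_mae.item():.4f}, score MAE: {score_mae.item():.4f}")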