Distillation of a MobileNet-based BlazeFace into a ResNet-based BlazeFace

torch-BlazefaceDistillation

Abstract

This project uses knowledge distillation to transfer the performance of Google's MobileNet-backbone BlazeFace (teacher) to an original ResNet-backbone BlazeFace (student).
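The abstract describes a teacher/student setup. As a minimal sketch, assuming the frozen teacher (MobileNet BlazeFace) and the trainable student (ResNet BlazeFace) both return a pair (raw_boxes, raw_scores), one distillation step could look like the following. TinyDetector, distill_step and simple_loss are hypothetical placeholders, not code from this repo, and the 896 anchors x 16 regression values mirror the usual BlazeFace front-model layout as an assumption. In practice the loss argument would be the kl_divergence_loss shown under Distillation structure.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed BlazeFace front-model layout: 896 anchors, 16 regression values each
# (4 box values + 6 keypoints x 2). Not taken from this repo.
NUM_ANCHORS = 896


class TinyDetector(nn.Module):
    """Hypothetical stand-in for a BlazeFace-style detector returning (raw_boxes, raw_scores)."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, hidden, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.boxes = nn.Linear(hidden, NUM_ANCHORS * 16)
        self.scores = nn.Linear(hidden, NUM_ANCHORS)

    def forward(self, x):
        h = self.body(x)
        return (self.boxes(h).view(-1, NUM_ANCHORS, 16),
                self.scores(h).view(-1, NUM_ANCHORS, 1))


def simple_loss(logits, target):
    # Placeholder loss (plain MSE on both heads); NOT the README's kl_divergence_loss.
    return F.mse_loss(logits[0], target[0]) + F.mse_loss(logits[1], target[1])


def distill_step(student, teacher, images, optimizer, loss_fn):
    """One optimisation step: the teacher only produces targets, the student is updated."""
    teacher.eval()
    with torch.no_grad():          # no gradients flow into the teacher
        target = teacher(images)
    student.train()
    logits = student(images)
    loss = loss_fn(logits, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    teacher, student = TinyDetector(), TinyDetector()
    opt = torch.optim.SGD(student.parameters(), lr=0.1)
    images = torch.randn(2, 3, 128, 128)
    for step in range(3):
        print(step, distill_step(student, teacher, images, opt, simple_loss))
```

The teacher runs in eval mode under torch.no_grad(), so only the student's parameters receive gradients; this mirrors the .detach() applied to the teacher outputs inside kl_divergence_loss.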

Distillation structure

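The loss has two parts. For the confidence scores, it uses a KL-divergence loss on temperature-scaled softmax outputs, blended with binary cross-entropy. For the box regressions, it uses an MSE loss on anchor-decoded boxes. The MobileNet model is the teacher (target) and the ResNet model is the student (logits).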
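To see what the temperature does in isolation, here is a small sketch of a temperature-scaled KL term in the common knowledge-distillation form: KL from the teacher's to the student's softened distribution, multiplied by T^2 so gradient magnitudes stay comparable across temperatures. soft_kl is a hypothetical helper; the README's loss additionally applies this to sigmoid scores across the anchor dimension (dim=1) and mixes in binary cross-entropy.

```python
import torch
import torch.nn.functional as F


def soft_kl(student_logits, teacher_logits, T):
    """KL(teacher || student) on temperature-scaled softmax distributions, scaled by T^2."""
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    # F.kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)


if __name__ == "__main__":
    teacher = torch.tensor([[2.0, 1.0, 0.1]])
    student = torch.tensor([[1.5, 1.2, 0.3]])
    for T in (0.5, 1.0, 4.0):
        # Larger T gives a flatter teacher distribution.
        print(T, F.softmax(teacher / T, dim=1), soft_kl(student, teacher, T).item())
```

Dividing by T > 1 flattens the distribution, while T < 1 (such as the T = 0.01 used in kl_divergence_loss) sharpens it toward the highest-scoring entry.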

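import torch.nn as nn
import torch.nn.functional as F

# load_anchors() and decode_boxes() are helpers not shown in this README.

def kl_divergence_loss(logits, target):
    # logits: student (ResNet) outputs, target: teacher (MobileNet) outputs,
    # each a pair (raw_boxes, raw_scores)
    T = 0.01       # softmax temperature
    alpha = 0.6    # weight of the KL term; (1 - alpha) weights the BCE term
    thresh = 100   # clamp range for raw scores before the sigmoid
    criterion = nn.MSELoss()

    # classification: per-anchor sigmoid scores; the teacher side is detached
    log2div = logits[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    tar2div = target[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1).detach()
    kl = nn.KLDivLoss(reduction="batchmean")(
        F.log_softmax(log2div / T, dim=1), F.softmax(tar2div / T, dim=1))
    closs = kl * (alpha * T * T) + F.binary_cross_entropy(log2div, tar2div) * (1 - alpha)

    # regression: MSE between anchor-decoded student and teacher boxes
    anchor = load_anchors("src/anchors.npy")  # reloaded on every call
    rlogits = decode_boxes(logits[0], anchor)
    rtarget = decode_boxes(target[0], anchor)
    rloss = criterion(rlogits, rtarget)

    return closs + rloss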
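The regression term compares boxes after decoding them against the anchors, rather than raw network offsets. load_anchors and decode_boxes are not shown in this README, so the following is a hypothetical decoder modeled on the anchor-relative scheme used by MediaPipe BlazeFace ports such as MediaPipePyTorch: anchors stored as (x_center, y_center, w, h), a scale of 128 for the 128x128 front model, and 6 keypoints. The repo's own helper may differ in layout or constants.

```python
import torch
import torch.nn.functional as F


def decode_boxes_sketch(raw, anchors, scale=128.0, num_keypoints=6):
    """Hypothetical decoder: raw (B, N, 4 + 2*K) offsets -> normalized
    [ymin, xmin, ymax, xmax, kp0_x, kp0_y, ...]; anchors (N, 4) = (x_c, y_c, w, h)."""
    xc = raw[..., 0] / scale * anchors[:, 2] + anchors[:, 0]
    yc = raw[..., 1] / scale * anchors[:, 3] + anchors[:, 1]
    w = raw[..., 2] / scale * anchors[:, 2]
    h = raw[..., 3] / scale * anchors[:, 3]
    parts = [yc - h / 2, xc - w / 2, yc + h / 2, xc + w / 2]
    for k in range(num_keypoints):
        off = 4 + 2 * k
        # Keypoints are offsets from the anchor centre, scaled by anchor size.
        parts.append(raw[..., off] / scale * anchors[:, 2] + anchors[:, 0])
        parts.append(raw[..., off + 1] / scale * anchors[:, 3] + anchors[:, 1])
    # torch.stack keeps the result differentiable w.r.t. raw.
    return torch.stack(parts, dim=-1)


def regression_loss(student_raw, teacher_raw, anchors):
    """MSE between decoded student and (detached) teacher boxes."""
    return F.mse_loss(decode_boxes_sketch(student_raw, anchors),
                      decode_boxes_sketch(teacher_raw.detach(), anchors))
```

Decoding first means the MSE is measured in normalized image coordinates, so errors on large and small anchors are weighted by their actual effect on the box rather than by the raw offset values.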

Distillation performance (ResNet-backbone BlazeFace)

1. MobileNet backbone (Google's pretrained TFLite model)

2. ResNet backbone (customized BlazeFace trained by distillation)

Training log

1. Loss curve

2. Output MAE (mean absolute error)
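The README does not define "output MAE" precisely. Assuming it is the mean absolute difference between student and teacher outputs on the same images, reported separately for the box and score heads, a hypothetical helper could be:

```python
import torch


def output_mae(student_out, teacher_out):
    """Hypothetical metric: mean absolute error per head between student and teacher.

    Both arguments are (boxes, scores) pairs of tensors with matching shapes.
    """
    with torch.no_grad():
        return {name: (s - t).abs().mean().item()
                for name, s, t in zip(("boxes", "scores"), student_out, teacher_out)}
```

Applying .sigmoid() to the score tensors before calling output_mae would report the score error in probability space, which is the space kl_divergence_loss compares.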

References

MediaPipePyTorch

tf-blazeface