Syencil/mobile-yolov5-pruning-distillation

Details of L1 sparsity training

xuanyuyt opened this issue · 0 comments

import torch


def compute_pruning_loss(p, prunable_modules, model, loss):
    ll1 = torch.zeros(1, device=p[0].device)  # accumulator on the same device as the predictions
    h = model.hyp  # hyperparameters
    red = 'mean'  # loss reduction (sum or mean); unused here
    if prunable_modules:  # skips both None and an empty list (avoids dividing by zero)
        for m in prunable_modules:
            ll1 += m.weight.norm(1)  # L1 norm of the BN layer's gamma values
        ll1 /= len(prunable_modules)  # mean L1 norm over the prunable BN layers
    ll1 *= h['sl']  # sparsity coefficient from the hyperparameters
    bs = p[0].shape[0]  # batch size
    loss += ll1 * bs
    return loss
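The arithmetic this function adds to `loss` can be sketched without PyTorch. This is a minimal sketch, assuming each BN layer's gammas are given as a plain list of floats (standing in for `m.weight`); the names `pruning_penalty` and `penalty_subgradient` are hypothetical. Besides the penalty value, it computes the subgradient the term pushes onto each gamma, which makes visible that `bs` scales the shrinkage linearly.

```python
from typing import List, Sequence


def _sign(x: float) -> int:
    """Return 1, 0 or -1; the subgradient of |x| is taken as 0 at x == 0."""
    return (x > 0) - (x < 0)


def pruning_penalty(gammas: Sequence[Sequence[float]], sl: float, bs: int) -> float:
    """Mean L1 norm of the BN gammas, times sl, times batch size.

    Hypothetical helper mirroring compute_pruning_loss; gammas[i] plays the
    role of prunable_modules[i].weight.
    """
    if not gammas:
        return 0.0
    mean_l1 = sum(sum(abs(g) for g in layer) for layer in gammas) / len(gammas)
    return mean_l1 * sl * bs


def penalty_subgradient(gammas: Sequence[Sequence[float]], sl: float, bs: int) -> List[List[float]]:
    """Subgradient of pruning_penalty w.r.t. each gamma: sl * bs / n_layers * sign(gamma)."""
    if not gammas:
        return []
    scale = sl * bs / len(gammas)
    return [[scale * _sign(g) for g in layer] for layer in gammas]
```

With two freshly initialised BN layers (gammas all 1.0, the PyTorch default, with 4 and 8 channels), `sl = 1e-4` and `bs = 16`, the penalty is 6 × 1e-4 × 16 = 0.0096 and every gamma receives a subgradient of 1e-4 × 16 / 2 = 8e-4; doubling `bs` doubles both.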

Could someone explain why, in this function, the mean L1 norm is multiplied by batch_size?
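One possible reading, offered as a hedged illustration rather than the repository author's answer: assume the `loss` passed in was already multiplied by the batch size, as the YOLOv5 `compute_loss` this repository appears to build on does (`return loss * bs, ...`). Under that assumption, scaling the L1 term by `bs` as well keeps its share of the total loss, and hence the effective weight of `sl`, independent of batch size. The helpers `total_loss` and `l1_share` below are hypothetical.

```python
def total_loss(det_mean: float, mean_l1: float, sl: float, bs: int, scale_l1: bool = True) -> float:
    """Detection loss plus L1 sparsity term.

    Assumption: the detection loss is a per-batch mean that the caller has
    already multiplied by bs (as in YOLOv5's compute_loss). scale_l1 toggles
    the extra `* bs` applied in compute_pruning_loss.
    """
    det = det_mean * bs
    l1 = sl * mean_l1 * (bs if scale_l1 else 1)
    return det + l1


def l1_share(det_mean: float, mean_l1: float, sl: float, bs: int, scale_l1: bool = True) -> float:
    """Fraction of the total loss contributed by the L1 sparsity term."""
    l1 = sl * mean_l1 * (bs if scale_l1 else 1)
    return l1 / total_loss(det_mean, mean_l1, sl, bs, scale_l1)


# With the `* bs` factor the share is the same for any batch size;
# without it, the sparsity term fades as bs grows.
for bs in (8, 64):
    print(bs, l1_share(2.0, 50.0, 1e-3, bs), l1_share(2.0, 50.0, 1e-3, bs, scale_l1=False))
```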