L1 稀疏化训练细节
xuanyuyt opened this issue · 0 comments
xuanyuyt commented
def compute_pruning_loss(p, prunable_modules, model, loss):
    """Add an L1 sparsity penalty on prunable layers' weights to ``loss``.

    The penalty is the mean, over ``prunable_modules``, of each module's
    ``weight.norm(1)``, scaled by ``model.hyp['sl']`` and then by the batch
    size ``p[0].shape[0]``.

    Args:
        p: Sequence of model output tensors. Only ``p[0]`` is used: for its
            device and for its leading dimension, taken as the batch size.
        prunable_modules: Sized iterable of modules exposing ``.weight``
            (per the original comment, BatchNorm layers whose weight is the
            gamma scale), or ``None``. ``None`` or an empty collection
            disables the penalty.
        model: Object with a ``hyp`` dict holding the ``'sl'`` sparsity
            coefficient. Only read when there are prunable modules.
        loss: Loss value the penalty is added to. It is not modified in place.

    Returns:
        ``loss + penalty * batch_size``. The penalty is a 1-element float
        tensor on ``p[0]``'s device (all zeros when the penalty is disabled).
    """
    # A 1-element zero tensor keeps the penalty's shape at [1], as before.
    ll1 = torch.zeros(1, device=p[0].device)
    # Truthiness also rejects an empty collection, which would give 0/0 = NaN.
    if prunable_modules:
        h = model.hyp  # hyperparameters
        for m in prunable_modules:
            ll1 = ll1 + m.weight.norm(1)  # L1 norm of the BN gamma weights
        ll1 = ll1 / len(prunable_modules) * h['sl']
    bs = p[0].shape[0]  # batch size
    # NOTE(review): scaling by bs presumably keeps the penalty on the same
    # scale as a detection loss the caller already multiplied by the batch
    # size (mean reduction followed by * bs). Confirm against the caller.
    return loss + ll1 * bs
请教一下：函数中的平均 L1 范数为什么要乘以 batch_size？