leoxiaobin/deep-high-resolution-net.pytorch

Bugs in `get_model_summary` and FLOPs calculation

jin-s13 opened this issue

Dear all,

Thanks for releasing the code!
I noticed some problems in the model FLOPs calculation.

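1. The class names `Conv2d` and `ConvTranspose2d` both contain `'Conv'`, so `get_model_summary` counts the FLOPs of both in the same branch:

   ```python
   if class_name.find("Conv") != -1 and hasattr(module, "weight"):
       ...  # the Conv2d formula below is applied to both module types
   ```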

However, the FLOPs of `Conv2d` and `ConvTranspose2d` should be calculated differently.

For `Conv2d`, whose weight has shape `(C_out, C_in / groups, kH, kW)` and is applied once per output position, it is

```python
flops = (torch.prod(torch.LongTensor(list(module.weight.data.size()))) * torch.prod(torch.LongTensor(list(output.size())[2:]))).item()
```

But for `ConvTranspose2d`, whose weight has shape `(C_in, C_out / groups, kH, kW)` and is applied once per input position, it should be

```python
flops = (torch.prod(torch.LongTensor(list(module.weight.data.size()))) * torch.prod(torch.LongTensor(list(input[0].size())[2:]))).item()
```
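A minimal sketch of one possible fix, assuming the hook keeps the `(module, input, output)` signature used by `get_model_summary`: dispatch on the module type with `isinstance` instead of matching `"Conv"` in the class name. `conv_flops` and `make_hook` are hypothetical helpers, not part of the repository; like the formulas above, the count is per sample and ignores the bias.

```python
import torch
import torch.nn as nn

# Explicit type tuples: the class names of both groups contain "Conv",
# so a substring match on the name cannot tell them apart.
TRANSPOSED = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
REGULAR = (nn.Conv1d, nn.Conv2d, nn.Conv3d)


def conv_flops(module, input, output):
    """Hypothetical helper: per-sample multiply-accumulates, bias ignored."""
    kernel_numel = module.weight.numel()
    if isinstance(module, TRANSPOSED):
        # Weight is (C_in, C_out / groups, k...): each *input* position
        # costs weight.numel() multiply-accumulates.
        spatial = input[0].shape[2:]
    elif isinstance(module, REGULAR):
        # Weight is (C_out, C_in / groups, k...): each *output* position
        # costs weight.numel() multiply-accumulates.
        spatial = output.shape[2:]
    else:
        return 0
    return kernel_numel * spatial.numel()


# Usage: record the count of each layer with a forward hook.
counts = {}


def make_hook(name):
    def hook(module, input, output):
        counts[name] = conv_flops(module, input, output)
    return hook


down = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
up = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1)
down.register_forward_hook(make_hook("down"))
up.register_forward_hook(make_hook("up"))

with torch.no_grad():
    up(down(torch.randn(1, 16, 64, 64)))  # 64x64 -> 32x32 -> 64x64
```

With these shapes, the original output-based formula would report four times the multiply-accumulates for the transposed layer, because its 64x64 output has four times as many positions as its 32x32 input.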

2. FLOPs of many other ops (e.g. BatchNorm, ReLU) are not counted at all, so they are missing from the reported total.
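Counting these ops would need additional branches in the hook. Below is a minimal sketch for `BatchNorm2d` and `ReLU`; the per-element costs are assumptions (FLOPs counters disagree on them), and `elementwise_flops` is a hypothetical helper, not part of the repository.

```python
import torch
import torch.nn as nn

# Assumed per-element costs (conventions differ between FLOPs counters):
# ReLU: one comparison per output element.
# BatchNorm (inference): folds into y = a * x + b, i.e. two ops per element.
PER_ELEMENT_COST = {
    nn.ReLU: 1,
    nn.BatchNorm2d: 2,
}


def elementwise_flops(module, output):
    """Hypothetical helper: per-sample op count for element-wise layers."""
    cost = PER_ELEMENT_COST.get(type(module))
    if cost is None:
        return 0
    # output[0] drops the batch dimension so the count is per sample,
    # matching the convolution formulas above.
    return cost * output[0].numel()


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
).eval()

totals = []
for layer in model:
    # The hook returns None (the value of append), so outputs are unchanged.
    layer.register_forward_hook(
        lambda m, inp, out: totals.append(elementwise_flops(m, out)))

with torch.no_grad():
    model(torch.randn(2, 3, 16, 16))
```

Other element-wise layers can be added to `PER_ELEMENT_COST` in the same way; note that functional ops such as tensor additions are not modules, so a module hook never sees them.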