MegEngine/Documentation

基于 module 的统计参数量的示例code微调

xgbj opened this issue · 2 comments

xgbj commented

在章节 https://megengine.org.cn/doc/stable/zh/user-guide/tools/module-stats.html

from megengine.hub import load
from megengine.utils.module_stats import module_stats

# 构建一个 net module,这里从 model hub 中获取 resnet18 模型
net = load("megengine/models", "resnet18", pretrained=True)

# 指定输入 shape
input_shape = (1, 3, 224, 224)

# Float model.
total_params, total_flops = module_stats(
    net, input_shape, log_params=True, log_flops=True
)
print("params {} flops {}".format(total_params, total_flops))

建议把input_shape改成input_shape_list,一般模型多会吃多个输入,像示例这样的写法容易造成误解,或者增加注释

感谢反馈,内部正在改进这块的组织

在最近更新的版本中对这个文档页面内容进行了一定的修改:

https://megengine.org.cn/doc/stable/zh/user-guide/tools/stats.html

请看一看现在的写法是否更有帮助理解,如果你有任何改进的想法,欢迎直接发起 Pull Request.