L1aoXingyu/code-of-learn-deep-learning-with-pytorch

你好,调用你的utils.py里面的train函数,准确率为0

gittigxuy opened this issue · 1 comments

我的运行环境是Python2.7,pytorch0.3.0

问题已经解决了,建议将你的utils.py当中的get_acc改为现在这个
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)
num_correct = (pred_label == label).sum().data[0]
return float(num_correct) / total

原来的num_correct为整数类型,除以total之后给截断了,导致准确率为0