longcw/yolo2-pytorch

Upgrade train.py to PyTorch 0.4.0

Erotemic opened this issue · 1 comment

The current training script breaks under PyTorch 0.4.0. Losses are now 0-dimensional tensors, so the `.numpy()[0]` indexing in train.py raises an IndexError. The fix is to replace lines 88-92 of train.py with:

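    # PyTorch 0.3 returns each loss as a one-element tensor, so take element [0].
    # PyTorch 0.4 returns 0-dim tensors; their NumPy arrays cannot be indexed.
    if torch.__version__.startswith('0.3'):
        bbox_loss += net.bbox_loss.data.cpu().numpy()[0]
        iou_loss += net.iou_loss.data.cpu().numpy()[0]
        cls_loss += net.cls_loss.data.cpu().numpy()[0]
        train_loss += loss.data.cpu().numpy()[0]
    else:
        # float() converts the 0-d NumPy array into a plain Python scalar
        bbox_loss += float(net.bbox_loss.data.cpu().numpy())
        iou_loss += float(net.iou_loss.data.cpu().numpy())
        cls_loss += float(net.cls_loss.data.cpu().numpy())
        train_loss += float(loss.data.cpu().numpy())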
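The failure can be reproduced with NumPy alone. The sketch below uses made-up arrays that mimic what `loss.data.cpu().numpy()` returns in each version: a one-element 1-D array in 0.3 and a 0-d array in 0.4. `index_first` is a hypothetical name for the old `[0]` idiom.

```python
import numpy as np

# Stand-ins for loss.data.cpu().numpy() (values are made up)
old_style = np.array([0.5])  # PyTorch 0.3: one-element 1-D array
new_style = np.array(0.5)    # PyTorch 0.4: 0-d array from a 0-dim tensor


def index_first(arr):
    """The pre-0.4 idiom used in train.py: take element 0."""
    return arr[0]


print(index_first(old_style))  # fine for the 1-D array
try:
    index_first(new_style)
except IndexError as exc:
    # A 0-d array has no axis to index, so [0] fails
    print("IndexError:", exc)

# ndarray.item() returns a Python scalar for any size-1 array, whatever its ndim
print(old_style.item(), new_style.item())
```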

Alternatively, in PyTorch 0.4 you can get the Python scalar directly with `net.bbox_loss.item()`.
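One way to avoid parsing the version string is to check for `.item()` and fall back to the NumPy route otherwise. This is a minimal sketch; `to_scalar` is a hypothetical helper, not part of yolo2-pytorch, and the fallback assumes the older `.data.cpu().numpy()` path shown above.

```python
def to_scalar(loss):
    """Return a scalar loss as a Python float.

    Hypothetical helper (not part of yolo2-pytorch).
    """
    if hasattr(loss, "item"):
        # PyTorch 0.4+: .item() works on any one-element tensor, including 0-dim ones
        return float(loss.item())
    # Older versions: ndarray.item() accepts a size-1 array of any ndim,
    # so no [0] indexing is needed
    return float(loss.data.cpu().numpy().item())
```

With this helper, lines 88-92 would reduce to `bbox_loss += to_scalar(net.bbox_loss)` and the same pattern for `iou_loss`, `cls_loss` and `train_loss` (using `to_scalar(loss)`).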