deep-learning-with-pytorch/dlwpt-code

Read labels from the module, not from some random file

simsong opened this issue

Chapter 2 uses this code to load the ImageNet class labels for the model:

with open('../data/p1ch2/imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

However, the chapter never explains where the file imagenet_classes.txt comes from.

It turns out the labels already ship with torchvision.models (version 0.13 and later) as metadata on the pretrained weights:

from torchvision import models
labels = models.ResNet101_Weights.DEFAULT.meta['categories']

It would be useful to update the book to make this clear, or at least to add it to the errata.