deep-learning-with-pytorch/dlwpt-code

Confusion regarding Normalized CIFAR10 dataset in Chapter 7

tataganesh opened this issue · 0 comments

In Chapter 7, the CIFAR10 dataset is initially loaded as:

cifar10 = datasets.CIFAR10(data_path, train=True, download=True)

Then, section 7.1.4 discusses the importance of normalizing the data. The transformed CIFAR10 dataset is loaded as:

transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616)),
    ]))
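
For reference, transforms.Normalize computes (x - mean) / std independently for each channel. Below is a minimal pure-Python sketch of that arithmetic for a single RGB pixel. normalize_pixel is a hypothetical helper written for illustration, not a torchvision function, and it assumes pixel values already scaled to [0, 1], as ToTensor produces.

```python
# Channel statistics quoted in the snippet above.
MEAN = (0.4915, 0.4823, 0.4468)
STD = (0.2470, 0.2435, 0.2616)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Hypothetical helper: per-channel (x - mean) / std for one RGB pixel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


# A pixel equal to the channel means maps to zero in every channel.
mean_pixel = normalize_pixel(MEAN)

# A white pixel lies above the mean, so every channel comes out positive.
white_pixel = normalize_pixel((1.0, 1.0, 1.0))
```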

However, in section 7.2.1, a dataset containing only the samples with labels 0 and 2 is created from the cifar10 variable:

cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]

I am assuming that the cifar10 variable here refers to the normalized CIFAR10 dataset. If so, would it be clearer to replace cifar10 with transformed_cifar10?

cifar2 = [(img, label_map[label]) for img, label in transformed_cifar10 if label in [0, 2]]

This would make it clear to anyone implementing these steps that the normalized data is now being used to train the network.
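
To illustrate why the variable name matters, here is a small sketch that runs without downloading CIFAR10. ToyDataset is a hypothetical stand-in for datasets.CIFAR10 that applies its transform on access, as torchvision datasets do. The scalar "images", the normalize function, and the label_map value {0: 0, 2: 1} are assumptions made for the example.

```python
class ToyDataset:
    """Hypothetical stand-in for datasets.CIFAR10: applies `transform` on access."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, label = self.samples[idx]  # raises IndexError past the end, which stops iteration
        if self.transform is not None:
            img = self.transform(img)
        return img, label


def normalize(pixel, mean=0.5, std=0.25):
    # Scalar version of (x - mean) / std; the statistics are made up for the example.
    return (pixel - mean) / std


# Each "image" is a single float so the effect of the transform is easy to see.
raw_samples = [(0.75, 0), (0.5, 1), (0.25, 2), (1.0, 0)]
label_map = {0: 0, 2: 1}  # assumed mapping from the original labels to two classes

toy = ToyDataset(raw_samples)                                   # plays the role of cifar10
transformed_toy = ToyDataset(raw_samples, transform=normalize)  # plays the role of transformed_cifar10

# The same comprehension yields raw or normalized samples depending on the source.
cifar2_raw = [(img, label_map[label]) for img, label in toy if label in [0, 2]]
cifar2_norm = [(img, label_map[label]) for img, label in transformed_toy if label in [0, 2]]

print(cifar2_raw)   # [(0.75, 0), (0.25, 1), (1.0, 0)]
print(cifar2_norm)  # [(1.0, 0), (-1.0, 1), (2.0, 0)]
```

The comprehension passes through whatever the dataset returns, so building cifar2 from the untransformed variable would silently skip normalization.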