YU1ut/MixMatch-pytorch

Performance declines when replacing data augmentation with PyTorch functions

yichuan9527 opened this issue · 3 comments

Hello, I replaced your data augmentation functions (RandomPadandCrop, RandomFlip) with the PyTorch augmentation functions (RandomCrop, RandomHorizontalFlip), but the performance drops from 94% to 66%. My data augmentation code is as follows:

```python
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
```

```python
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
```

To run the code, I modified CIFAR10_labeled as follows:
```python
import numpy as np
import torchvision
from PIL import Image


class CIFAR10_labeled(torchvision.datasets.CIFAR10):

    def __init__(self, root, indexs=None, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_labeled, self).__init__(root, train=train,
                                              transform=transform,
                                              target_transform=target_transform,
                                              download=download)
        if indexs is not None:
            # Keep only the selected labelled examples.
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

        # self.data = transpose(normalise(self.data))  # the repo's original preprocessing
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the index of the target class.
        """
        img, target = self.data[index], self.targets[index]
        # Convert to a PIL image so the torchvision transforms can be applied.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
```
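For reference, a minimal usage sketch of how I construct the dataset (train_idxs is a hypothetical array of labelled indices; the arguments follow the class above):

```python
# Hypothetical example: build the labelled training set with the
# torchvision-based transform defined earlier.
train_labeled_set = CIFAR10_labeled(root='./data', indexs=train_idxs,
                                    train=True, transform=transform_train,
                                    download=True)
img, target = train_labeled_set[0]  # img is a normalized tensor after the transform
```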

This is very strange! In theory, the two implementations (your functions and the PyTorch augmentation functions) are the same, yet there is a substantial performance gap.

Hi! I have recently been trying to adapt this code to the Clothing1M dataset, and I carefully checked its data augmentation.

It seems that if you use the official PyTorch augmentation functions, these lines of code are not necessary (torchvision's CIFAR10 already stores self.data as an (N, 32, 32, 3) uint8 HWC array, so reinterpreting it as (-1, 3, 32, 32) and transposing back will scramble the pixel layout):

```python
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
```
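If so, a minimal sketch of what the simplified __init__ could look like (keeping everything else the same; torchvision's CIFAR10 already holds self.data in HWC layout, so Image.fromarray works on it directly):

```python
import numpy as np
import torchvision


class CIFAR10_labeled(torchvision.datasets.CIFAR10):

    def __init__(self, root, indexs=None, train=True,
                 transform=None, target_transform=None, download=False):
        super(CIFAR10_labeled, self).__init__(
            root, train=train, transform=transform,
            target_transform=target_transform, download=download)
        if indexs is not None:
            # self.data is already an (N, 32, 32, 3) uint8 array,
            # so subsetting is the only step needed here.
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]
```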

Also, this implementation of MixMatch:

  1. first normalizes and transposes (converting a PIL image to a tensor-style array) and then pads and crops; the implementation is correct, but the order differs from the official PyTorch augmentations, which usually pad and crop first, then convert to a tensor and normalize;
  2. pads with mode='reflect', so you may need to change the default padding_mode of the official PyTorch augmentation (see the sketch after this list):
     `torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')`
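For point 2, a minimal sketch of a training transform that matches the repository's reflect padding (the mean/std values are taken from the original post; assumes a torchvision version whose RandomCrop supports padding_mode):

```python
import torchvision.transforms as transforms

# Use reflect padding to match RandomPadandCrop in this repo,
# instead of torchvision's default zero ('constant') padding.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2471, 0.2435, 0.2616)),
])
```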

I've set this up here. It's not tested for the full training time, but after 100 epochs with 250 labelled examples I get the same score ±0.5%.

@guixianjin
For Clothing1M, did you train with multiple GPUs?