libffcv/ffcv

Top-1 accuracy on ImageNet drops between runs -- only difference is FFCV

nelaturuharsha opened this issue · 2 comments

Hello,

I am training a ResNet50 on ImageNet for a project and noticed the following issues:

  • Top-1 accuracy drops significantly, and the PyTorch DataLoader version is almost always better on the test set.
  • The difference in training-time metrics is less drastic, but there is still around a 2% per-epoch drop in accuracy compared to the PyTorch DataLoader run.
  • There is little to no speedup from FFCV on an RTX A6000 GPU.

For some context

  • The dataset was generated with the ffcv-imagenet repository using the standard parameters 500, 0.50, 90; a sketch of the writer configuration follows this list. The dataloader object I created is provided further below in case there is an issue with it.
  • Every other component is exactly the same for the two runs.
  • In the first epoch with FFCV the time is 38 minutes per epoch, and this does not improve at all. With a PyTorch DataLoader it is 41 minutes per epoch, so the difference is barely noticeable.
  • The training was done in FP32.
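For reference, the beton was written with the ffcv-imagenet defaults; the configuration is roughly equivalent to the sketch below (paths and train_dataset are stand-ins, and the field arguments reflect the 500 / 0.50 / 90 parameters):

import torchvision
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField

# Stand-in for the raw ImageNet training set.
train_dataset = torchvision.datasets.ImageFolder('../imagenet-data/train')

writer = DatasetWriter('../imagenet-data/train_500_0.50_90.ffcv', {
    # 500: max side length, 0.50: fraction of images stored as JPEG, 90: JPEG quality
    'image': RGBImageField(write_mode='proportion',
                           max_resolution=500,
                           compress_probability=0.50,
                           jpeg_quality=90),
    'label': IntField(),
}, num_workers=16)

writer.from_indexed_dataset(train_dataset)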
import numpy as np
import torch

from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, CenterCropRGBImageDecoder
from ffcv.transforms import (RandomHorizontalFlip, ToTensor, ToDevice,
                             ToTorchImage, NormalizeImage, Squeeze)


class FFCVImageNet:
    def __init__(self, args):
        super(FFCVImageNet, self).__init__()

        data_root = '../imagenet-data/'

        # FFCV's NormalizeImage operates on 0-255 pixel values, so scale the
        # usual ImageNet statistics accordingly.
        IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
        IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
        DEFAULT_CROP_RATIO = 224 / 256
        train_image_pipeline = [RandomResizedCropRGBImageDecoder((224, 224)),
                            RandomHorizontalFlip(),
                            ToTensor(),
                            ToDevice(torch.device('cuda:0'), non_blocking=True),
                            ToTorchImage(),
                            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32)]

        val_image_pipeline = [CenterCropRGBImageDecoder((256, 256), ratio=DEFAULT_CROP_RATIO),
                              ToTensor(),
                              ToDevice(torch.device('cuda:0'), non_blocking=True),
                              ToTorchImage(),
                              NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32)]

        label_pipeline = [IntDecoder(),
                            ToTensor(),
                            Squeeze(),
                            ToDevice(torch.device('cuda:0'), non_blocking=True)]



        self.train_loader = Loader(data_root + 'train_500_0.50_90.ffcv',
                              batch_size  = args.batch_size,
                              num_workers = args.workers,
                              order       = OrderOption.QUASI_RANDOM,
                              os_cache    = True,
                              drop_last   = True,
                              pipelines   = { 'image' : train_image_pipeline,
                                              'label' : label_pipeline},
                              )

        self.val_loader = Loader(data_root + 'val_500_0.50_90.ffcv',
                            batch_size  = args.batch_size,
                            num_workers = args.workers,
                            order       = OrderOption.SEQUENTIAL,
                            drop_last   = False,
                            pipelines   = { 'image' : val_image_pipeline,
                                            'label' : label_pipeline},
                            )
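For what it's worth, this is how I sanity-check the loader throughput in isolation (a minimal sketch; args is the same namespace passed to FFCVImageNet):

import time
import torch

data = FFCVImageNet(args)

# Measure pure data-loading throughput, independent of the model.
torch.cuda.synchronize()
start = time.time()
n_images = 0
for images, labels in data.train_loader:
    n_images += images.shape[0]
torch.cuda.synchronize()
print(f'{n_images / (time.time() - start):.1f} images/sec')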

[screenshot: top-1 accuracy curves, FFCV vs. PyTorch DataLoader]

As you can see above, the performance is far worse when FFCV is used.

I would appreciate any insight into why this is happening and what could be done to improve it.

Thanks!

Hi @SreeHarshaNelaturu! What training code are you using?

Hi @andrewilyas, thanks for the prompt response. This is a custom harness we wrote ourselves for training + pruning networks. Could you let me know which specific parts I should send across?

Btw, I found one interesting result that seems to have almost fixed the problem: passing shuffle_indices=True while regenerating the beton.
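Concretely, the only writer-side change was roughly this (reusing the DatasetWriter sketch from my first post; train_dataset is again a stand-in):

# shuffle_indices=True randomizes the sample order inside the beton at write
# time, instead of writing samples in the dataset's (class-sorted) order.
writer.from_indexed_dataset(train_dataset, shuffle_indices=True)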

The loss/accuracy curves now look like this
[screenshot: loss/accuracy curves for the three runs]

Blue - FFCV (w/o shuffle_indices=True while creating the beton)
Orange - PyTorch DataLoader
Maroon - FFCV (w/ shuffle_indices=True while creating the beton)

Hope this helps people out; I think this is related to the use of OrderOption.QUASI_RANDOM, as indicated in issue #304.
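For anyone who would rather not rewrite their beton, the loader-side alternative would presumably be a full shuffle instead of QUASI_RANDOM (I have not benchmarked this myself; the pipelines are the ones defined above):

from ffcv.loader import Loader, OrderOption

# OrderOption.RANDOM performs a full shuffle every epoch, at the cost of a
# larger memory footprint than QUASI_RANDOM.
train_loader = Loader(data_root + 'train_500_0.50_90.ffcv',
                      batch_size=args.batch_size,
                      num_workers=args.workers,
                      order=OrderOption.RANDOM,
                      os_cache=True,
                      drop_last=True,
                      pipelines={'image': train_image_pipeline,
                                 'label': label_pipeline})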

Regarding speed-up

I also tested this on the CelebA dataset (and can provide code to reproduce): there was little to no speedup from FFCV, and the throughput gains on ImageNet are similarly small (38 min vs. 41 min per epoch) [Device: NVIDIA A6000].

Is this speedup only to be expected on mixed-precision training?
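That is, should the training loop look roughly like this before FFCV pays off? (A minimal torch.cuda.amp sketch, not our actual harness; model, criterion, optimizer, and train_loader are stand-ins.)

import torch

scaler = torch.cuda.amp.GradScaler()

for images, labels in train_loader:
    optimizer.zero_grad(set_to_none=True)
    # Run the forward pass and loss in reduced precision where safe.
    with torch.cuda.amp.autocast():
        loss = criterion(model(images), labels)
    # Scale the loss to avoid FP16 gradient underflow, then step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()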

(Please let me know if it's better to open a separate issue for the speedup-related part.)

Cheers!