NVIDIA/DALI

DALI is slower than PyTorch DataLoader only when encountering a large dataset

jackdaw213 opened this issue · 11 comments

Describe the question.

# Imports added for completeness; `utils` is the author's own helper module.
import os

import torch
from PIL import Image
from torchvision import transforms

from nvidia.dali import pipeline_def, fn, types

import utils


class StyleDataset(torch.utils.data.Dataset):
    def __init__(self, content_dir, style_dir):
        self.content = os.listdir(content_dir)
        self.style = os.listdir(style_dir)
        self.pair = list(zip(self.content, self.style))

        self.content_dir = content_dir
        self.style_dir = style_dir

        self.transform = transforms.Compose([
            transforms.Resize((512), antialias=True),
            transforms.RandomCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.pair)

    def __getitem__(self, index):
        content, style = self.pair[index]

        content = os.path.join(self.content_dir, content)
        style = os.path.join(self.style_dir, style)

        content = Image.open(content).convert("RGB")
        style = Image.open(style).convert("RGB")

        content = self.transform(content)
        style = self.transform(style)

        return content, style
    
    @staticmethod
    @pipeline_def(device_id=0)
    def dali_pipeline(content_dir, style_dir):
        content_images, _ = fn.readers.file(file_root=content_dir, 
                                            files=utils.list_images(content_dir),
                                            random_shuffle=True, 
                                            name="Reader")
        
        style_images, _ = fn.readers.file(file_root=style_dir, 
                                            files=utils.list_images(style_dir),
                                            random_shuffle=True)
        
        content_images = fn.decoders.image(content_images, device="mixed", output_type=types.RGB)
        style_images = fn.decoders.image(style_images, device="mixed", output_type=types.RGB)

        content_images = fn.resize(content_images, size=512, dtype=types.FLOAT)
        style_images = fn.resize(style_images, size=512, dtype=types.FLOAT)

        content_images = fn.crop_mirror_normalize(content_images, 
                                                dtype=types.FLOAT,
                                                crop=(256, 256),
                                                crop_pos_x=fn.random.uniform(range=(0, 1)),
                                                crop_pos_y=fn.random.uniform(range=(0, 1)),
                                                mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                                std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
        style_images = fn.crop_mirror_normalize(style_images, 
                                                dtype=types.FLOAT,
                                                crop=(256, 256),
                                                crop_pos_x=fn.random.uniform(range=(0, 1)),
                                                crop_pos_y=fn.random.uniform(range=(0, 1)),
                                                mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                                std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

        return content_images, style_images

The above is my code, along with the measured times for running the training loop over the entire COCO 2017 dataset and for running a specific number of batches (batch size is 4; num_workers/num_threads is 4).

| Batch num      | DataLoader     | DALI           | Speed up |
|----------------|----------------|----------------|----------|
| 1              | 0:00:12.213287 | 0:00:00.243510 | 51x      |
| 100            | 0:00:25.307245 | 0:00:12.591142 | 2x       |
| 1000           | 0:02:18.732055 | 0:02:07.409853 | 1.06x    |
| 5000           | 0:10:49.306911 | 0:10:47.580684 | ~1x      |
| Entire dataset | 4 hours        | 5.5 hours      | 0.72x    |
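
For reference, a comparison like this could be driven by a small timing harness along the following lines; this is a hedged sketch, not the author's actual measurement code, and the `time_batches` helper and loader names are illustrative:

import itertools
from datetime import datetime

def time_batches(loader, n_batches=None):
    # Iterate a data loader for n_batches (or a full epoch when None) and return the elapsed time.
    start = datetime.now()
    batches = iter(loader) if n_batches is None else itertools.islice(iter(loader), n_batches)
    for _ in batches:
        pass  # the forward/backward pass and optimizer step would go here in the real loop
    return datetime.now() - start

# e.g. time_batches(torch_loader, 1000) vs. time_batches(dali_iterator, 1000)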

Why is the DataLoader faster than DALI when training with the entire dataset, despite being 50 times slower when loading just one batch? I observed that as the number of batches increases, the performance gap between DALI and the DataLoader narrows until the DataLoader eventually surpasses DALI. This pattern holds whether or not the forward/backward pass and gradient descent are included in the testing.

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report

Hi @jackdaw213,

Thank you for running a performance comparison between DALI and the PyTorch data loader. As I don't know your particular configuration, it may be that the reads for the first batches are served from the disk cache, giving DALI an advantage on cached data. Also, DALI's native readers use a single thread to access data, which works well for local storage; for remote storage one thread is not enough to saturate the IO, and the PyTorch data loader, which can read with multiple CPU threads in parallel, can sometimes win. You can check whether the parallel external source approach provides better performance.

Hello @JanuszL

As I don't know your particular configuration

All of my data sits on an M.2 SSD; please let me know if you need anything else.

it may be that the reads for the first batches are served from the disk cache, giving DALI an advantage on cached data

Thanks for the info. However, why does DALI gradually lose its lead over the DataLoader as the number of batches increases? DALI is faster for the first few batches, so shouldn't it be able to maintain that advantage?

Also, DALI's native readers use a single thread to access data, which works well for local storage; for remote storage one thread is not enough to saturate the IO, and the PyTorch data loader, which can read with multiple CPU threads in parallel, can sometimes win. You can check whether the parallel external source approach provides better performance.

# Imports added for completeness.
import os

from torchvision import io


class ExternalInputCallable:
    def __init__(self, batch_size, content_dir, style_dir):
        self.content = os.listdir(content_dir)
        self.style = os.listdir(style_dir)

        self.pair = list(zip(self.content, self.style))

        self.content_dir = content_dir
        self.style_dir = style_dir

        self.full_iterations = len(self.pair) // batch_size

    def __call__(self, sample_info):
        sample_idx = sample_info.idx_in_epoch

        if sample_info.iteration >= self.full_iterations:
            # Indicate end of the epoch
            raise StopIteration()
        
        content, style = self.pair[sample_idx]

        # torchvision.io returns CHW uint8 tensors; permute to HWC, which DALI expects
        content = io.read_image(os.path.join(self.content_dir, content), mode=io.ImageReadMode.RGB).permute(1, 2, 0)
        style = io.read_image(os.path.join(self.style_dir, style), mode=io.ImageReadMode.RGB).permute(1, 2, 0)
        
        return content.numpy(), style.numpy()
        
@staticmethod
@pipeline_def(device_id=0, batch_size=4, py_num_workers=4, py_start_method="spawn")
def dali_pipeline(content_dir, style_dir):

    content_images, style_images = fn.external_source(
        source=exinput.ExternalInputCallable(4, content_dir, style_dir), 
        num_outputs=2,
        parallel=True, batch=False,
    )

    content_images = fn.resize(content_images.gpu(), size=512, dtype=types.FLOAT)
    style_images = fn.resize(style_images.gpu(), size=512, dtype=types.FLOAT)


    content_images = fn.crop_mirror_normalize(content_images, 
                                            dtype=types.FLOAT,
                                            crop=(256, 256),
                                            crop_pos_x=fn.random.uniform(range=(0, 1)),
                                            crop_pos_y=fn.random.uniform(range=(0, 1)),
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    style_images = fn.crop_mirror_normalize(style_images, 
                                            dtype=types.FLOAT,
                                            crop=(256, 256),
                                            crop_pos_x=fn.random.uniform(range=(0, 1)),
                                            crop_pos_y=fn.random.uniform(range=(0, 1)),
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    return content_images, style_images

I tried an external source with py_num_workers=4, but it does not provide any speed-up:

| Batch num | DataLoader     | DALI           | DALI External Source |
|-----------|----------------|----------------|----------------------|
| 100       | 0:00:25.307245 | 0:00:12.591142 | 0:00:24.723208 (probably because of overhead) |
| 1000      | 0:02:18.732055 | 0:02:07.409853 | 0:02:07.996955 |
| 5000      | 0:10:49.306911 | 0:10:47.580684 | 0:10:50.747156 |

Based on the symptoms and the numbers you shared, I am convinced that the reads for the first batches are served from the disk cache, so at the very beginning processing time is the speed limiter and the read time for this data is negligible. However, you cannot keep the whole data set in RAM, and once the data has to be read directly from storage the IO dominates. At that point the processing is only as fast as the IO, which is the same in both cases, and that is why you see no difference. What you can do is use more RAM so the data fits in the disk cache, or use more nodes for training so that each GPU works on a smaller piece of the data set (a shard) which can be cached more easily.
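
For reference, sharding with DALI's native readers is controlled by the shard_id and num_shards arguments of fn.readers.file; a minimal sketch of what a per-GPU shard of the pipeline above might look like (the function name and argument plumbing are illustrative):

from nvidia.dali import pipeline_def, fn, types

@pipeline_def(device_id=0)
def sharded_dali_pipeline(content_dir, style_dir, shard_id, num_shards):
    # Each shard only ever touches 1/num_shards of the files, so its working set
    # is smaller and more likely to stay in the disk cache.
    content_images, _ = fn.readers.file(file_root=content_dir,
                                        shard_id=shard_id,
                                        num_shards=num_shards,
                                        random_shuffle=True,
                                        name="Reader")
    style_images, _ = fn.readers.file(file_root=style_dir,
                                      shard_id=shard_id,
                                      num_shards=num_shards,
                                      random_shuffle=True)
    content_images = fn.decoders.image(content_images, device="mixed", output_type=types.RGB)
    style_images = fn.decoders.image(style_images, device="mixed", output_type=types.RGB)
    # ... resize / crop_mirror_normalize as in the original pipeline ...
    return content_images, style_images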

What could cause DALI to fall behind DataLoader when training with the entire dataset? They should theoretically have the same training time, but DALI is 1.5 hours slower than DataLoader.

Use more nodes for training so that each GPU works on a smaller piece of the data set (a shard) which can be cached more easily

I only have a single GPU in my system, so that is probably not an option for me.

I'm still interested to see whether the DALI external source makes any difference in the long run; at 5000 batches it is very comparable.

It seems that you are correct: DALI ES is 3 minutes faster than the DataLoader when running through the entire dataset. Also, I just noticed that DALI supports GPUDirect Storage, but why is it only available for the .npy reader? Is it because CUDA can't decode JPEG without the special hardware found on the A100?

Is it because CUDA can't decode JPEG without the special hardware found on the A100?

It is because in most cases the data needs to be at least partially parsed on the CPU. Even for JPEG decoding we need to check whether the data stream is correct before we attempt to decode it with the special hardware. In the case of numpy, we just load the raw data into GPU memory, and only a small part of the file (the header) needs to be parsed on the CPU.
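
For illustration, the GPUDirect Storage path corresponds to the GPU variant of DALI's numpy reader; a minimal sketch (the directory layout and pipeline name are assumptions):

from nvidia.dali import pipeline_def, fn

@pipeline_def(device_id=0)
def gds_numpy_pipeline(data_dir):
    # With device="gpu" the numpy reader can move the raw array data straight into GPU
    # memory (via GPUDirect Storage where available); only the .npy header is parsed on the CPU.
    data = fn.readers.numpy(device="gpu",
                            file_root=data_dir,
                            file_filter="*.npy",
                            random_shuffle=True,
                            name="Reader")
    return data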

That's good to know. I will try to figure out a way to minimize the IO bottleneck; I did not know it could be that important.

That's good to know. I will try to figure out a way to minimize the IO bottleneck; I did not know it could be that important.

You may try putting your data into other formats like TFRecord, RecordIO, or WebDataset and see whether files packed together provide better IO throughput.
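
As an illustration of the packed-format route, DALI ships a WebDataset reader; a minimal sketch, assuming the images have been packed into a single .tar archive with .jpg components (the archive path is a placeholder):

from nvidia.dali import pipeline_def, fn, types

@pipeline_def(device_id=0)
def webdataset_pipeline(tar_path):
    # One large, sequentially read archive replaces many small per-image reads.
    jpegs = fn.readers.webdataset(paths=[tar_path],
                                  ext=["jpg"],
                                  missing_component_behavior="error",
                                  random_shuffle=True,
                                  name="Reader")
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
    return images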

I will look into those, thank you for the suggestion.
Also, is there a way to pass the batch_size argument to the external source callable directly? I can make dali_pipeline take batch_size as an argument, but it would be nice not to have to pass the batch size twice.

train_loader = DALIGenericIterator(
    [dataset.StyleDataset.dali_pipeline(content_dir=train_dir_content,
                                        style_dir=train_dir_style,
                                        batch_size=batch_size,  ### How to get this arg
                                        num_threads=4)],
    ['content', 'style'],
)

    @staticmethod
    @pipeline_def(device_id=0, py_num_workers=4, py_start_method="spawn")
    def dali_pipeline(content_dir, style_dir):
        content_images, style_images = fn.external_source(
            source=exinput.ExternalInputCallable(content_dir, style_dir, batch_size),  ### and give it to the source?
            num_outputs=2,
            parallel=True,
            batch=False
        )

@jackdaw213,

Thank you for your suggestion. It is entirely plausible to have such a thing. However, currently the batch size is defined in one place for the processing pipeline and in a second place for the Python data generation (ExternalInputCallable). In the case of ExternalInputCallable it is optional, as in some cases it may generate any number of samples (and the batch size may vary from iteration to iteration).
We will discuss how to make this more convenient for the user in future releases.
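
In the meantime, a possible workaround is to define the batch size once in the calling scope and forward the same variable to both the pipeline constructor and the callable; a hedged sketch (make_train_loader is illustrative, and exinput.ExternalInputCallable refers to the callable shown earlier in the thread):

from nvidia.dali import pipeline_def, fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator

import exinput  # the author's module containing ExternalInputCallable

@pipeline_def(device_id=0, py_num_workers=4, py_start_method="spawn")
def dali_pipeline(content_dir, style_dir, external_batch_size):
    # external_batch_size is an ordinary argument of the graph-defining function,
    # deliberately named differently from the pipeline's own batch_size keyword.
    content_images, style_images = fn.external_source(
        source=exinput.ExternalInputCallable(external_batch_size, content_dir, style_dir),
        num_outputs=2,
        parallel=True,
        batch=False,
    )
    # ... resize / crop_mirror_normalize as in the earlier pipeline ...
    return content_images, style_images

def make_train_loader(content_dir, style_dir, batch_size, num_threads=4):
    # batch_size is defined once here and forwarded to both places that need it.
    pipe = dali_pipeline(content_dir, style_dir,
                         external_batch_size=batch_size,
                         batch_size=batch_size,
                         num_threads=num_threads)
    return DALIGenericIterator([pipe], ['content', 'style'])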