Lightning-Universe/lightning-bolts

AsynchronousLoader deteriorates performance

Pedrexus opened this issue · 1 comment

๐Ÿ› Bug

When using `AsynchronousLoader`, the test performance of the trained model sometimes deteriorates. I have not yet found the root cause, so I wonder whether someone knows how to fix this and could provide a fix.

I am applying `AsynchronousLoader` through a callback, but I am not sure whether that is part of the problem. Please help!

To Reproduce

The script below reproduces the comparison: just add or remove the `AsyncDataloading` callback from the `Trainer` callbacks list.

My results:

  1. No AsyncDataloading: test_acc = 0.9056, test_loss = 0.3065, t = 4m 32.2s
  2. With AsyncDataloading: test_acc = 0.7591, test_loss = 0.8348, t = 1m 53.8s
    test_loss almost tripled, although the runtime was more than halved (the speed-up is why I would like to use AsyncDataloading).

Code sample

import os

import torch
import torchvision
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy
from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock, Bottleneck
from pytorch_lightning.callbacks import Callback
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader


# Seed Python, NumPy and torch (including DataLoader workers) for reproducible runs.
seed_everything(1234, workers=True)

PATH_DATASETS = "datasets/"
BATCH_SIZE = 256
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader rejects num_workers=None; fall back to in-process loading (0).
NUM_WORKERS = os.cpu_count() or 0


# Training-time augmentation: random 32x32 crop with 4px padding, horizontal flip,
# tensor conversion, then CIFAR-10 normalization.
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

# 20% of the CIFAR-10 training set is held out for validation (val_split=0.2).
cifar10_dm = CIFAR10DataModule(
    data_dir=PATH_DATASETS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_transforms=train_transforms,
    seed=1234,
    val_split=0.2,
    normalize=True,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)


def create_model():
    """Return a torchvision ``ResNet`` with one ``BasicBlock`` per stage and 10 output classes.

    The stem convolution is swapped for a 3x3, stride-1, padding-1 convolution and the
    max-pool layer is replaced by an identity, so the first stage sees the input at
    full spatial resolution.
    """
    net = ResNet(block=BasicBlock, layers=[1, 1, 1, 1], num_classes=10)
    # Replacement stem: 3 input channels -> 64 channels, no bias.
    net.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Disable the stem down-sampling pool.
    net.maxpool = torch.nn.Identity()
    return net


class LitResnet(LightningModule):
    """CIFAR-10 classifier wrapping ``create_model()``, trained with SGD + OneCycleLR.

    Args:
        lr: learning rate stored in ``self.hparams`` and passed to SGD. OneCycleLR
            (see ``configure_optimizers``) overwrites the optimizer's starting learning
            rate with ``max_lr / div_factor``, so this value does not shape the schedule.
    """

    def __init__(self, lr=0.05):
        super().__init__()

        self.save_hyperparameters()
        self.model = create_model()
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the network's class logits for the batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the cross-entropy loss of one batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for one ``(x, y)`` batch.

        When ``stage`` is truthy the metrics are logged to the progress bar as
        ``{stage}_loss`` and ``{stage}_acc``; otherwise nothing is logged.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=10)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` for one validation batch."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` for one test batch."""
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Create SGD and a per-step OneCycleLR schedule that spans the whole fit.

        The schedule length is taken from ``trainer.estimated_stepping_batches``
        instead of a hard-coded dataset size. The previous ``45000 // BATCH_SIZE``
        assumed a 45000-image training split, but with ``val_split=0.2`` the
        CIFAR-10 training split holds 40000 images (157 batches of 256). The
        old schedule therefore ended ~10% before completion and never annealed
        the learning rate to its final value.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,  # max_lr
                total_steps=self.trainer.estimated_stepping_batches,
            ),
            # Advance the schedule after every optimizer step, matching total_steps.
            "interval": "step",
            "frequency": 1,
            # Only consulted by metric-driven schedulers; must name a logged metric.
            "monitor": "val_loss",
            "strict": False,
            "name": None,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

class AsyncDataloading(Callback):
    """Replace the fit loaders with ``AsynchronousLoader`` wrappers on single-device runs.

    The train loader and every validation loader are wrapped once, at training start,
    with a queue size of 10 and the device given by ``trainer.device_ids[0]``.

    NOTE(review): the wrappers created here are reused for every epoch. The issue
    reportedly recovers test accuracy with ``reload_dataloaders_every_n_epochs=1``
    (i.e. fresh loaders each epoch), so per-epoch iteration state inside the reused
    wrapper is a suspect for the regression; confirm before relying on this callback.
    """

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Runs on more than one device keep their original loaders.
        if trainer.num_devices <= 1:
            target = torch.device("cuda", trainer.device_ids[0])

            def wrap(loader):
                return AsynchronousLoader(loader, device=target, q_size=10, num_batches=None)

            trainer.train_dataloader.loaders = wrap(trainer.train_dataloader.loaders)
            trainer.val_dataloaders = list(map(wrap, trainer.val_dataloaders))

if __name__ == "__main__":
    # Entry-point guard: DataLoader workers started with the "spawn" method re-import
    # this module and must not launch a second training run.
    trainer = Trainer(
        max_epochs=30,
        accelerator="gpu",
        precision=16,
        amp_backend="native",
        devices=[2],
        callbacks=[AsyncDataloading()],
        num_nodes=1,
        # Only takes effect when `devices` is an integer; inert with an explicit list.
        auto_select_gpus=True,
        strategy="dp",
        amp_level=None,
        sync_batchnorm=False,
        profiler=None,
        benchmark=False,
        deterministic=False,
        detect_anomaly=False,
        # These two only run when trainer.tune() is called; trainer.fit() ignores them.
        auto_lr_find=True,
        auto_scale_batch_size="power",
        default_root_dir=".",
        log_every_n_steps=50,
        check_val_every_n_epoch=1,
        num_sanity_val_steps=2,
        gradient_clip_val=1,
        gradient_clip_algorithm="value",
        enable_progress_bar=True,
        enable_model_summary=True,
        enable_checkpointing=True,
    )

    model = LitResnet(lr=0.05)

    trainer.fit(model, cifar10_dm)
    # Evaluates the weights held by `model` after the final training epoch.
    trainer.test(model, datamodule=cifar10_dm)

Expected behavior

Using `AsynchronousLoader` should not affect model performance (test accuracy and loss).

Environment

  • Pytorch and libs installed from Docker image nvcr.io/nvidia/pytorch:22.10-py3
  • GPU models and configuration: 1x NVIDIA RTX A5000, 24 GB

Additional info

I also tried wrapping the loaders inside a custom datamodule instead of a callback, but I still see a similar deterioration in performance:
test_acc = 0.7943, test_loss = 0.6380, t = 1m 59.2s

class AsyncCIFAR10DataModule(CIFAR10DataModule):
    """``CIFAR10DataModule`` whose dataloaders are wrapped in ``AsynchronousLoader``."""

    def _data_loader(self, dataset, shuffle: bool = False) -> AsynchronousLoader:
        """Build the parent's DataLoader and wrap it in an ``AsynchronousLoader``.

        When a trainer is attached, batches are sent to ``trainer.device_ids[0]``
        (the same device the ``AsyncDataloading`` callback uses). Previously the
        wrapper always used its default device (``cuda:0``), which is the wrong GPU
        for a run configured with ``devices=[2]``. Without an attached trainer the
        library default is kept.

        NOTE(review): Lightning calls ``train_dataloader()`` once unless dataloader
        reloading is enabled, so the same wrapper instance serves every epoch;
        the issue reports that reloading every epoch restores accuracy.
        """
        loader = super()._data_loader(dataset, shuffle)
        trainer = getattr(self, "trainer", None)
        if trainer is None:
            return AsynchronousLoader(loader)
        return AsynchronousLoader(loader, device=torch.device("cuda", trainer.device_ids[0]))

# Same configuration as the plain CIFAR10DataModule, but with asynchronous loaders.
cifar10_dm = AsyncCIFAR10DataModule(
    # dataset location and train/validation split
    data_dir=PATH_DATASETS,
    val_split=0.2,
    seed=1234,
    # batching
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    # loader workers and host memory
    num_workers=NUM_WORKERS,
    pin_memory=True,
    # transforms
    train_transforms=train_transforms,
    normalize=True,
)

I was able to work around the accuracy drop by setting `reload_dataloaders_every_n_epochs=1` in the `Trainer` initialization, but the script runtime increased to 3m 34s.