learnables/learn2learn

TypeError: An invalid dataloader was passed to `Trainer.fit(train_dataloaders=...)`. Found <learn2learn.utils.lightning.EpisodicBatcher object at 0x7f3bade9bca0>.

joshuasv opened this issue · 0 comments

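When a `learn2learn.utils.lightning.EpisodicBatcher` is passed to a `lightning.Trainer`, training crashes. As the output below shows, `EpisodicBatcher` derives from `pytorch_lightning.core.datamodule.LightningDataModule` rather than `lightning.LightningDataModule`. The `lightning` Trainer therefore does not recognise it as a datamodule and rejects it as an invalid train dataloader.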

Minimal reproducible example (MRE)

import sys
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

import torch
import lightning
import learn2learn
from torch.utils.data import DataLoader
from torch.nn import Module, Linear
from lightning import LightningDataModule, LightningModule, Trainer
from learn2learn.data.task_dataset import Taskset
from learn2learn.utils.lightning import EpisodicBatcher

# Record the interpreter and library versions so the report can be reproduced.
for _version in (
    sys.version_info,
    torch.__version__,
    lightning.__version__,
    learn2learn.__version__,
):
    logging.warning(_version)
del _version  # keep the module namespace identical to the unrolled form

class DummyDataset(torch.utils.data.Dataset):
    """Ten-sample toy dataset: item ``i`` is ``(float32 scalar tensor i, 1)``."""

    def __init__(self):
        super().__init__()
        # Features are the values 0.0 .. 9.0 held in a 1-D float32 tensor.
        self.data = torch.arange(10, dtype=torch.float32)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = 1  # every sample carries the same constant label
        return sample, label

class DummyLightningDataModule(LightningDataModule):
    """Minimal ``lightning`` datamodule serving :class:`DummyDataset` with batch size 1."""

    def __init__(self):
        super().__init__()

    def train_dataloader(self):
        """Return the training loader over a fresh :class:`DummyDataset`."""
        return DataLoader(DummyDataset(), batch_size=1)

    def val_dataloader(self):
        """Return the validation loader over a fresh :class:`DummyDataset`.

        Lightning only looks up the hook named ``val_dataloader``; under the
        previous name ``validation_dataloader`` it was never called and the
        validation loop was skipped.
        """
        return DataLoader(DummyDataset(), batch_size=1)

    # Backward-compatible alias for the previous (non-hook) method name.
    validation_dataloader = val_dataloader

class DummyModel(Module):
    """Minimal network: a single ``Linear(1, 1)`` layer."""

    def __init__(self):
        super().__init__()
        # One scalar-in, scalar-out affine layer.
        self.model = Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        layer = self.model
        return layer(x)

class DummyLightningModule(LightningModule):
    """LightningModule wrapper around :class:`DummyModel` used to drive ``Trainer.fit``."""

    def __init__(self):
        super().__init__()
        self.model = DummyModel()

    def forward(self, x):
        """Delegate to the wrapped :class:`DummyModel`."""
        net = self.model
        return net(x)

    def training_step(self, batch, batch_idx):
        # NOTE(review): the raw model output is returned as the "loss"; this
        # presumably only back-propagates because each batch yields a
        # single-element output — confirm before reusing with larger batches.
        inputs = batch[0]
        return self(inputs)

    def validation_step(self, batch, batch_idx):
        inputs = batch[0]
        return self(inputs)

    def configure_optimizers(self):
        # Learning rate 0. means Adam never changes the weights; presumably
        # only the training-loop plumbing is meant to be exercised here.
        params = self.parameters()
        return torch.optim.Adam(params, 0.)

dataset = DummyDataset()
taskset = Taskset(dataset)
datamodule = EpisodicBatcher(train_tasks=taskset)
l_datamodule = DummyLightningDataModule()
model = DummyLightningModule()


# Baseline: a datamodule derived from `lightning.LightningDataModule` trains fine.
logging.warning("Train with dummy lightning classes")
logging.warning(isinstance(l_datamodule, LightningDataModule))
logging.warning(type(l_datamodule).__base__)
t = Trainer(max_epochs=5, enable_model_summary=False, enable_progress_bar=False)
t.fit(model, l_datamodule)

# Failing case: EpisodicBatcher is NOT an instance of `lightning.LightningDataModule`;
# its logged base class is `pytorch_lightning.core.datamodule.LightningDataModule`,
# a distinct class from the separate `pytorch_lightning` package. `Trainer.fit`
# therefore treats it as a train dataloader and raises TypeError.
logging.warning("Train with learn2learn classes")
logging.warning(isinstance(datamodule, LightningDataModule))
logging.warning(type(datamodule).__base__)
t = Trainer(max_epochs=5, enable_model_summary=False, enable_progress_bar=False)
t.fit(model, datamodule)

Output

2024-01-25 08:38:42,375 - WARNING - sys.version_info(major=3, minor=10, micro=6, releaselevel='final', serial=0)
2024-01-25 08:38:42,375 - WARNING - 2.1.0a0+b5021ba
2024-01-25 08:38:42,375 - WARNING - 2.1.3
2024-01-25 08:38:42,375 - WARNING - 0.2.0
2024-01-25 08:38:42,376 - WARNING - Train with dummy lightning classes
2024-01-25 08:38:42,376 - WARNING - True
2024-01-25 08:38:42,376 - WARNING - <class 'lightning.pytorch.core.datamodule.LightningDataModule'>
2024-01-25 08:38:42,525 - INFO - GPU available: True (cuda), used: True
2024-01-25 08:38:42,525 - INFO - TPU available: False, using: 0 TPU cores
2024-01-25 08:38:42,525 - INFO - IPU available: False, using: 0 IPUs
2024-01-25 08:38:42,525 - INFO - HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2024-01-25 08:38:42,665 - INFO - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
2024-01-25 08:38:42,873 - INFO - `Trainer.fit` stopped: `max_epochs=5` reached.
2024-01-25 08:38:42,875 - WARNING - Train with lear2learn classes
2024-01-25 08:38:42,875 - WARNING - False
2024-01-25 08:38:42,875 - WARNING - <class 'pytorch_lightning.core.datamodule.LightningDataModule'>
2024-01-25 08:38:42,885 - INFO - GPU available: True (cuda), used: True
2024-01-25 08:38:42,885 - INFO - TPU available: False, using: 0 TPU cores
2024-01-25 08:38:42,885 - INFO - IPU available: False, using: 0 IPUs
2024-01-25 08:38:42,885 - INFO - HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2024-01-25 08:38:42,887 - INFO - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 402, in _check_dataloader_iterable
    iter(dataloader)  # type: ignore[call-overload]
TypeError: 'EpisodicBatcher' object is not iterable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/code/test.py", line 68, in <module>
    t.fit(model, datamodule)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1035, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 194, in run
    self.setup_data()
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 237, in setup_data
    _check_dataloader_iterable(dl, source, trainer_fn)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 407, in _check_dataloader_iterable
    raise TypeError(
TypeError: An invalid dataloader was passed to `Trainer.fit(train_dataloaders=...)`. Found <learn2learn.utils.lightning.EpisodicBatcher object at 0x7f77a32c8640>.