Error while adding CIFAR10 and CIFAR100
ArianKhorasani opened this issue · 3 comments
Dear @krasserm,
I'm currently attempting to integrate CIFAR10 and CIFAR100 datasets into your code to train the perceiver-io model. Following the approach you've taken for MNIST, I've created a cifar10.py
file within perceiver/data/vision
and followed your steps. However, when I attempt to run the training example, I consistently encounter an error stating that 'image' has not been defined.
I have also ensured that I imported CIFAR10DataModule
into the train.py
script located in examples/training/img_clf
. Despite these efforts, I'm still unable to successfully execute your code with CIFAR10 and CIFAR100.
Thank you in advance for your assistance!
Hi @ArianKhorasani, can you provide code for reproducing the problem?
Hi @ArianKhorasani, can you provide code for reproducing the problem?
Dear @krasserm! Here is the cifar10.py file that I added:
import os
from typing import Optional
import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from perceiver.data.vision.common import channels_to_last, ImagePreprocessor, lift_transform
class CIFAR10Preprocessor(ImagePreprocessor):
def __init__(self, normalize: bool = True, channels_last: bool = True):
super().__init__(cifar10_transform(normalize, channels_last))
class CIFAR10DataModule(pl.LightningDataModule):
def __init__(
self,
dataset_dir: str = os.path.join(".cache", "cifar10"),
normalize: bool = True,
channels_last: bool = True,
random_crop: Optional[int] = None,
batch_size: int = 64,
num_workers: int = 3,
pin_memory: bool = True,
shuffle: bool = True,
):
super().__init__()
self.save_hyperparameters()
self.channels_last = channels_last
self.tf_train = cifar10_transform(normalize, channels_last, random_crop=random_crop)
self.tf_valid = cifar10_transform(normalize, channels_last, random_crop=None)
self.ds_train = None
self.ds_valid = None
@property
def num_classes(self):
return 10
@property
def image_shape(self):
if self.hparams.channels_last:
return 32, 32, 3
else:
return 3, 32, 32
def load_dataset(self, split: Optional[str] = None):
return load_dataset("cifar10", split=split, cache_dir=self.hparams.dataset_dir)
def prepare_data(self) -> None:
self.load_dataset()
def setup(self, stage: Optional[str] = None) -> None:
self.ds_train = self.load_dataset(split="train")
self.ds_train.set_transform(lift_transform(self.tf_train))
self.ds_valid = self.load_dataset(split="test")
self.ds_valid.set_transform(lift_transform(self.tf_valid))
def train_dataloader(self):
return DataLoader(
self.ds_train,
shuffle=self.hparams.shuffle,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
)
def val_dataloader(self):
return DataLoader(
self.ds_valid,
shuffle=False,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
)
def cifar10_transform(normalize: bool = True, channels_last: bool = True, random_crop: Optional[int] = None):
transform_list = []
if random_crop is not None:
transform_list.append(transforms.RandomCrop(random_crop))
transform_list.append(transforms.ToTensor())
if normalize:
transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
if channels_last:
transform_list.append(channels_to_last)
return transforms.Compose(transform_list)
Please note that after adding this file, I did import CIFAR10DataModule and CIFAR10Preprocessor in image_classifier.py file in /perceiver/scripts/vision path!
Would be happy to hear your thoughts on it! Thanks!
You need to rename the image column in the dataset:
def load_dataset(self, split: Optional[str] = None):
return load_dataset("cifar10", split=split, cache_dir=self.hparams.dataset_dir).rename_column("img", "image")