lucidrains/res-mlp-pytorch

torch dataset example


I wrote this example with a data loader:

import os
import natsort
from PIL import Image
import torch
import torchvision.transforms as T
from res_mlp_pytorch.res_mlp_pytorch import ResMLP

class LPCustomDataSet(torch.utils.data.Dataset):
    '''
        Naive Torch Image Dataset Loader
        with support for Image loading errors
        and Image resizing
    '''
    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        try:
            image = Image.open(img_loc).convert("RGB")
            tensor_image = self.transform(image)
            return tensor_image
        except OSError:
            # skip unreadable images; collate_fn filters out the None
            return None

    @classmethod
    def collate_fn(cls, batch):
        '''
            Collate, filtering out None images from failed loads
        '''
        batch = list(filter(lambda x: x is not None, batch))
        return torch.utils.data.dataloader.default_collate(batch)

    @classmethod
    def transform(cls, img):
        '''
            Naive image resizer
        '''
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        return transform(img)

to feed ResMLP:

model = ResMLP(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)
batch_size = 2
my_dataset = LPCustomDataSet(os.path.join(os.path.dirname(
    os.path.abspath(__file__)), 'data'), transform=LPCustomDataSet.transform)
train_loader = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=4, drop_last=True,
                                           collate_fn=LPCustomDataSet.collate_fn)
for idx, img in enumerate(train_loader):
    pred = model(img) # (2, 1000)
    print(idx, img.shape, pred.shape)

But I get this error:

RuntimeError: Given groups=1, weight of size [256, 256, 1], expected input[1, 196, 512] to have 256 channels, but got 196 channels instead

I'm not sure if LPCustomDataSet.transform produces the correct size for the input image.
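
For reference, the numbers in the error correspond to patch counts: the model was built with image_size = 256 and patch_size = 16, so its cross-patch layer expects (256 / 16)^2 = 256 patch tokens, while Resize(256) followed by CenterCrop(224) yields a 224x224 image, i.e. (224 / 16)^2 = 196 patches. A minimal sketch of the arithmetic (variable names are illustrative only, not part of the res_mlp_pytorch API):

# patch count the model expects vs. what the transform produces
expected_patches = (256 // 16) ** 2      # 256 -> the 256 channels in the weight
actual_patches = (224 // 16) ** 2        # 196 -> the 196 channels in the input
print(expected_patches, actual_patches)  # 256 196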

@loretoparisi Try setting your image size to 224 instead

@lucidrains I have tried setting the transform to

transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

and the image shape is

0 torch.Size([1, 3, 224, 224])

but I am still getting

RuntimeError: Given groups=1, weight of size [256, 256, 1], expected input[1, 196, 512] to have 256 channels, but got 196 channels instead

Ah wait, it works OK now. I forgot to change the model's image size:

# Res MLP
res_model = ResMLP(
    image_size = 224,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)
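
With image_size = 224 in both the model and the transform, the shapes line up. As a quick sanity check (mirroring the README usage, with a random tensor standing in for one loader batch):

img = torch.randn(2, 3, 224, 224)  # stand-in for one batch from train_loader
pred = res_model(img)
print(pred.shape)  # torch.Size([2, 1000])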