pytorch/data

torchdata gives much lower accuracy than the vanilla DataLoader

Closed this issue · 0 comments

jc-su commented

๐Ÿ› Describe the bug

When I use torchdata for CIFAR-10 training, I get much lower accuracy than with the vanilla DataLoader. After 100 epochs of training ResNet-18 through torchdata I only reach about 60% accuracy, whereas the vanilla DataLoader easily reaches 80%+.
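
For context, the vanilla-DataLoader baseline being compared against is not included in this report; the sketch below shows what such a baseline typically looks like (torchvision.datasets.CIFAR10 with the same transforms and normalization constants as the datapipe version that follows; the batch size and worker count here are placeholders). Note that it yields (image, label) pairs, so the train/test loops further down would have to drop the data_id element.

# Assumed vanilla baseline (sketch only, not the actual script used for the 80%+ run)
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
train_transform = transforms.Compose([
    transforms.TrivialAugmentWide(
        interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

train_set = datasets.CIFAR10("dataset", train=True, download=True,
                             transform=train_transform)
test_set = datasets.CIFAR10("dataset", train=False, download=True,
                            transform=test_transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True,   # batch size is a placeholder
                          num_workers=4, drop_last=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)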

# Imports and constants assumed by the snippet below (not included in the report)
import os
import pathlib
import subprocess
from typing import Union

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.datapipes.utils.decoder import imagehandler  # import path assumed
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import FileLister
from torchvision import models, transforms
from tqdm import tqdm

BATCH_SIZE = 128  # placeholder; the actual value is not given in the report
EPOCHS = 100      # the torchdata run trains for 100 epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Dataset:
    def __init__(self, root: Union[str, pathlib.Path], batch_size: int):
        self.root = root
        self.batch_size = batch_size

    def _datapipe(self):
        pass

    def load(self, train=True):
        if train:
            return self._train_datapipe()
        else:
            return self._test_datapipe()


class CIFAR10(Dataset):
    def __init__(self, root: Union[str, pathlib.Path], batch_size: int):
        self.train_path = os.path.join(root, "cifar10/train")
        self.test_path = os.path.join(root, "cifar10/test")
        self.batch_size = batch_size

    def _apply_transform(self, data, train=True):
        if train:
            image_transforms = transforms.Compose([
                transforms.TrivialAugmentWide(
                    interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
            ])
        else:
            image_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
            ])
        data["image"] = image_transforms(data["image"])
        return data

    def _decode(self, item):
        # item is a (path, binary stream) pair produced by open_files
        key, value = item
        decoder = imagehandler("pil")
        image = decoder("jpg", value.read())
        # the class label is encoded in the file name: the token after the last
        # "_" and before the extension
        label = os.path.basename(key).split(".")[0].split("_")[-1]
        return {"image": image, "label": int(label), "full_path": key}

    def _collate_fn(self, batch):
        # batch is a list of sample dicts; stack them into (images, labels, paths)
        if isinstance(batch[0]['image'], torch.Tensor):
            labels = torch.Tensor([sample['label']
                                   for sample in batch]).to(torch.long)
            images = torch.stack([sample['image'] for sample in batch])
            full_path = [sample['full_path'] for sample in batch]
        return images, labels, full_path

    def _train_datapipe(self):
        # shuffle the file list before sharding so each worker gets a shard of
        # the same shuffled order, then open the files as binary streams
        train_dp = (
            FileLister(self.train_path, recursive=True)
            .shuffle()
            .sharding_filter()
            .open_files(mode="rb")
        )
        train_dp = train_dp.map(self._decode).map(self._apply_transform)
        train_dp = train_dp.batch(self.batch_size, drop_last=True).collate(self._collate_fn)

        return train_dp

    def _test_datapipe(self):
        test_dp = (
            FileLister(self.test_path, recursive=True)
            .shuffle()
            .sharding_filter()
            .open_files(mode="rb")
        )
        test_dp = test_dp.map(self._decode).map(self._apply_transform)
        test_dp = test_dp.batch(self.batch_size).collate(self._collate_fn)

        return test_dp

def get_datapipe(root, batch_size, train=True):
    dp = CIFAR10(root, batch_size).load(train)
    return dp


def train(train_loader, optimizer, model, criterion, device):
    model.train()
    for i, (data, target, data_id) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()



def test(test_loader, model, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    total_len = 0
    with torch.no_grad():
        for i, (data, target, _) in enumerate(tqdm(test_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_len += data.size(0)
    test_loss /= total_len
    test_acc = 100. * correct / total_len
    return test_loss, test_acc

root = "dataset"
train_dp = get_datapipe(
    root=root, batch_size=BATCH_SIZE, train=True)
test_dp = get_datapipe(
    root=root, batch_size=BATCH_SIZE, train=False)

model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

max_process_num = int(subprocess.check_output("nproc"))  # queried but not used below
rs = MultiProcessingReadingService(num_workers=1)
train_loader = DataLoader2(
            train_dp,
            reading_service=rs,
        )
test_loader = DataLoader2(
            test_dp,
            reading_service=rs,
        )

for epoch in range(EPOCHS):
    train_loader.seed(epoch)
    test_loader.seed(epoch)
    train(train_loader, optimizer, model, criterion, device)
    test_loss, acc = test(test_loader, model, criterion, device)
    print(f"epoch {epoch}: test loss {test_loss:.4f}, test accuracy {acc:.2f}%")

Vanilla DataLoader: Test set: Average loss: 0.6406, Accuracy: 8090/10000 (81%) at epoch 60
torchdata: Accuracy: about 60% at epoch 100
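
To narrow down whether the accuracy gap comes from the data pipeline itself, one sanity check (a diagnostic sketch, not part of the original report; collect_paths is a hypothetical helper) is to confirm that a single epoch of the train datapipe yields all 50,000 CIFAR-10 training images exactly once and that the order changes when the loader is re-seeded:

# Diagnostic sketch: count the files seen in one epoch and compare the order
# across two different seeds.
def collect_paths(loader, epoch):
    loader.seed(epoch)
    paths = []
    for _, _, batch_paths in loader:   # batches are (images, labels, full_path)
        paths.extend(batch_paths)
    return paths

paths_a = collect_paths(train_loader, 0)
paths_b = collect_paths(train_loader, 1)
print(len(paths_a), len(set(paths_a)))  # expect ~50000 unique paths (minus the drop_last remainder)
print(paths_a[:10] != paths_b[:10])     # expect True if shuffling differs between epochs

If either check fails, the datapipe rather than the model or optimizer settings is the likely source of the gap.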


### Versions

PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31

Python version: 3.9.16 (main, Mar  8 2023, 14:00:05)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 527.41
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   39 bits physical, 48 bits virtual
CPU(s):                          20
On-line CPU(s) list:             0-19
Thread(s) per core:              2
Core(s) per socket:              10
Socket(s):                       1
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           151
Model name:                      12th Gen Intel(R) Core(TM) i7-12700F
Stepping:                        2
CPU MHz:                         2112.000
BogoMIPS:                        4224.00
Virtualization:                  VT-x
Hypervisor vendor:               Microsoft
Virtualization type:             full
L1d cache:                       480 KiB
L1i cache:                       320 KiB
L2 cache:                        12.5 MiB
L3 cache:                        25 MiB
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm serialize flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
[pip3] torchdata==0.6.1
[pip3] torchtext==0.15.1
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.23.5           py39h14f4228_0  
[conda] numpy-base                1.23.5           py39h31eccc5_0  
[conda] pytorch                   2.0.1           py3.9_cuda11.8_cudnn8.7.0_0    pytorch
[conda] pytorch-cuda              11.8                 h7e8668a_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     2.0.0                    pypi_0    pypi
[conda] torchaudio                2.0.2                py39_cu118    pytorch
[conda] torchdata                 0.6.1                    pypi_0    pypi
[conda] torchtext                 0.15.1                   pypi_0    pypi
[conda] torchtriton               2.0.0                      py39    pytorch
[conda] torchvision               0.15.2               py39_cu118    pytorch
[conda] triton                    2.0.0                    pypi_0    pypi