torchdata has a very low accuracy
Closed this issue ยท 0 comments
jc-su commented
๐ Describe the bug
When I used torchdata for CIFAR-10 training I got very low accuracy compared to vanilla dataloader. after 100epoch training with resnet18 in torchdata I only got about 60% accuracy, but vanilla dataloader was easy It reached 80%+.
class Dataset:
def __init__(self, root: Union[str, pathlib.Path], batch_size: int):
self.root = root
self.batch_size = batch_size
def _datapipe(self):
pass
def load(self, train=True):
if train:
return self._train_datapipe()
else:
return self._test_datapipe()
class CIFAR10(Dataset):
def __init__(self, root: Union[str, pathlib.Path], batch_size: int):
self.train_path = os.path.join(root, "cifar10/train")
self.test_path = os.path.join(root, "cifar10/test")
self.batch_size = batch_size
def _apply_transform(self, data, train=True):
if train:
image_transforms = transforms.Compose([
transforms.TrivialAugmentWide(
interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
else:
image_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
data["image"] = image_transforms(data["image"])
return data
def _decode(self, item):
key, value = item
decoder = imagehandler("pil")
image = decoder("jpg", value.read())
label = os.path.basename(key).split(".")[0].split("_")[-1]
return {"image": image, "label": int(label), "full_path": key}
def _collate_fn(self, batch):
if isinstance(batch[0]['image'], torch.Tensor):
labels = torch.Tensor([sample['label']
for sample in batch]).to(torch.long)
images = torch.stack([sample['image'] for sample in batch])
full_path = [sample['full_path'] for sample in batch]
return images, labels, full_path
def _train_datapipe(self):
train_dp = FileLister(self.train_path, recursive=True).shuffle().sharding_filter().open_files(
mode="rb")
train_dp = train_dp.map(self._decode).map(self._apply_transform)
train_dp = train_dp.batch(self.batch_size, drop_last=True).collate(self._collate_fn)
return train_dp
def _test_datapipe(self):
test_dp = FileLister(self.test_path, recursive=True).shuffle().sharding_filter().open_files(
mode="rb")
test_dp = test_dp.map(self._decode).map(self._apply_transform)
test_dp = test_dp.batch(self.batch_size).collate(self._collate_fn)
return test_dp
def get_datapipe(root, batch_size, train=True):
dp = CIFAR10(root, batch_size).load(train)
return dp
def train(train_loader, optimizer, model, criterion, device):
model.train()
for i, (data, target, data_id) in enumerate(tqdm(train_loader)):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def test(test_loader, model, criterion, device):
model.eval()
test_loss = 0
correct = 0
total_len = 0
with torch.no_grad():
for i, (data, target, _) in enumerate(tqdm(test_loader)):
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item() * data.size(0)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
total_len += data.size(0)
test_loss /= total_len
test_acc = 100. * correct / total_len
return test_loss, test_acc
root = "dataset"
train_dp = get_datapipe(
root=root, batch_size=BATCH_SIZE, train=True)
test_dp = get_datapipe(
root=root, batch_size=BATCH_SIZE, train=False)
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
max_process_num = int(subprocess.check_output("nproc"))
rs = MultiProcessingReadingService(num_workers=1)
train_loader = DataLoader2(
train_dp,
reading_service=rs,
)
test_loader = DataLoader2(
test_dp,
reading_service=rs,
)
for epoch in range(EPOCHS):
train_loader.seed(epoch)
test_loader.seed(epoch)
train(train_loader, optimizer, model, criterion, device)
test_loss, acc = test(test_loader, model, criterion, device)
Vanilla: Test set: Average loss: 0.6406, Accuracy: 8090/10000 (81%) in epoch 60
Torchdata: Accuracy: 60% in epoch 100
### Versions
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31
Python version: 3.9.16 (main, Mar 8 2023, 14:00:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 527.41
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 39 bits physical, 48 bits virtual
CPU(s): 20
On-line CPU(s) list: 0-19
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 151
Model name: 12th Gen Intel(R) Core(TM) i7-12700F
Stepping: 2
CPU MHz: 2112.000
BogoMIPS: 4224.00
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB
L1i cache: 320 KiB
L2 cache: 12.5 MiB
L3 cache: 25 MiB
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm serialize flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
[pip3] torchdata==0.6.1
[pip3] torchtext==0.15.1
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.5 py39h14f4228_0
[conda] numpy-base 1.23.5 py39h31eccc5_0
[conda] pytorch 2.0.1 py3.9_cuda11.8_cudnn8.7.0_0 pytorch
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 2.0.0 pypi_0 pypi
[conda] torchaudio 2.0.2 py39_cu118 pytorch
[conda] torchdata 0.6.1 pypi_0 pypi
[conda] torchtext 0.15.1 pypi_0 pypi
[conda] torchtriton 2.0.0 py39 pytorch
[conda] torchvision 0.15.2 py39_cu118 pytorch
[conda] triton 2.0.0 pypi_0 pypi