pytorch/data

Iterating a data pipe created with random_split ends in an error because the code tries to iterate past the data pipe's length


๐Ÿ› Describe the bug

Iterating through a data pipe generated by a random split correctly goes through all the data it is supposed to return, but unfortunately it does not stop there and ends with an error: ValueError: Total of weights must be greater than zero (even though the weights were correct).
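The message itself comes from Python's standard `random.choices`, which raises it when the weights it is given sum to zero or less. A minimal sketch (standard library only) showing that the 0.8/0.2 weights from the report are accepted, while all-zero weights are rejected with exactly this message. The idea that the splitter's internal weights reach zero is an assumption, based on the `self.weights[index] -= 1` line in the error trace:

```python
import random

rng = random.Random(42)

# The weights from the report are valid input for random.choices.
picked = rng.choices(["train", "test"], [0.8, 0.2])[0]
print(picked in {"train", "test"})  # True

# Weights that have all dropped to zero are rejected with the reported message.
try:
    rng.choices(["train", "test"], [0, 0])
except ValueError as err:
    print(err)  # Total of weights must be greater than zero
```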

The error occurs regardless of the number of rows (read from a CSV file in my case). If I specify NumRows as 10 with a train weight of 0.8, the example outputs 8 rows and then fails; if I specify NumRows as 100 with the same weight, it outputs 80 rows before failing. So the problem is not the file content either.

A temporary workaround is to calculate the number of lines the datapipe is supposed to return and manually break out of the for loop at the last one, but that is not really workable in the long run.
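The workaround can be sketched with `itertools.islice`, which stops after a fixed number of items without requesting one more from the underlying iterator, so the failing extra step is never reached. The generator `failing_split` is hypothetical: it only mimics the reported behaviour (8 rows, then the `ValueError`) so the sketch runs without torchdata. In the report the iterable would be `train_dataset`, and computing the expected count as `round(NumRows * 0.8)` assumes the split size is simply weight × `total_length`:

```python
from itertools import islice
from typing import Iterator


def failing_split() -> Iterator[int]:
    """Hypothetical stand-in for train_dataset: yields 8 rows, then fails."""
    yield from range(8)
    raise ValueError("Total of weights must be greater than zero")


num_rows = 10       # value passed as total_length
train_weight = 0.8  # weight of the "train" split
expected = round(num_rows * train_weight)  # assumed size of the split: 8

# islice stops after `expected` items and never asks for the next one,
# so the error raised on the extra step is not reached.
rows = list(islice(failing_split(), expected))
print(len(rows))  # 8
```

This has the same fragility the report points out: the count must be recomputed whenever `NumRows` or the weights change.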

I also checked whether this is a problem with all datapipes generated from a CSV file, but it is not: I could iterate through a "full" datapipe (i.e., one not generated via random_split) without problems.
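Since the full datapipe iterates cleanly, it could also be used to count how many rows the source actually yields (e.g. `sum(1 for _ in datapipe)`) and compare that with the `total_length` passed to `random_split`; this comparison is a suggestion, not something tried in this report. A self-contained sketch of the same count using the standard `csv` module on made-up text standing in for the CSV file, skipping one header row in the spirit of `skip_lines=1`:

```python
import csv
import io

# Made-up CSV text standing in for dataset/train-processed-sample.csv.
text = (
    "id,text\n"
    "1,first tweet\n"
    '2,"a tweet, with a comma"\n'
    '3,"a tweet\nspanning two lines"\n'
)

reader = csv.reader(io.StringIO(text), delimiter=",")
next(reader)  # skip the header row
num_rows = sum(1 for _ in reader)  # parsed rows after the header

raw_lines = text.count("\n") - 1  # naive count of lines after the header
print(num_rows, raw_lines)  # 3 4
```

Counting parsed rows rather than raw lines matters for tweet data, since quoted fields can contain newlines.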

This is with pytorch version 2.1.2 and torchdata 0.7.1

Versions

import torchdata.datapipes as dp

from torchdata.datapipes.iter import FileOpener

FileList_dp = dp.iter.IterableWrapper(["dataset/train-processed-sample.csv"])
datapipe = FileOpener(FileList_dp, mode='rt')
datapipe = datapipe.parse_csv(delimiter=',', skip_lines=1)

NumRows=10
train_dataset, test_dataset = datapipe.random_split(total_length=NumRows, weights={"train": 0.8, "test": 0.2 }, seed=42)

for sample in train_dataset:
    el_5 = sample[5]
    print(el_5)

#########################################################################

The code here outputs 8 tweets, as I am reading a CSV file of tweets:

@zellegatoc I buy dvds (Blu-ray) from the pirates of Metrowalk. Spotted Astro's buy-one-get-one-free offer on movies yesterday... vcds
TVMA is now on Twitter!!
@KirstieMAllsopp not nice when they grow up
Hv 2 b up in 4 & a half hrs so I'm gone! <3 u all! Tweet cha later!
I'm allergic to cats. what do I do? do not want to give up Cookie ever.
@kmcooley oh it's time doe me to get 1.6, I lost track. Heading there now.
Wot a braw day its turned out ti b
@LisaMantchev Very cool video. I watched and re-tweeted.

Then it fails with an error. Here is the error trace:


ValueError                                Traceback (most recent call last)
Cell In[2], line 4
      1 NumRows=10
      2 train_dataset, test_dataset = datapipe.random_split(total_length=NumRows, weights={"train": 0.8, "test": 0.2 }, seed=42)
----> 4 for sample in train_dataset:
      5     el_5 = sample[5]
      6     print(el_5)

File ~/miniconda3/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py:195, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    193         else:  # Decided against using contextlib.nullcontext for performance reasons
    194             _check_iterator_valid(datapipe, iterator_id)
--> 195             response = gen.send(request)
    196 except StopIteration as e:
    197     return

File ~/miniconda3/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/randomsplitter.py:186, in SplitterIterator.__iter__(self)
    184 self.main_datapipe.reset()
    185 for sample in self.main_datapipe.source_datapipe:
--> 186     if self.main_datapipe.draw() == self.target:
    187         yield sample

File ~/miniconda3/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/randomsplitter.py:105, in _RandomSplitterIterDataPipe.draw(self)
    104 def draw(self) -> T:
--> 105     selected_key = self._rng.choices(self.keys, self.weights)[0]
    106     index = self.key_to_index[selected_key]
    107     self.weights[index] -= 1

File ~/miniconda3/envs/pytorch/lib/python3.11/random.py:509, in Random.choices(self, population, weights, cum_weights, k)
    507 total = cum_weights[-1] + 0.0  # convert to float
    508 if total <= 0.0:
--> 509     raise ValueError('Total of weights must be greater than zero')
    510 if not _isfinite(total):
    511     raise ValueError('Total of weights must be finite')

ValueError: Total of weights must be greater than zero
This exception is thrown by __iter__ of SplitterIterator(main_datapipe=_RandomSplitterIterDataPipe, target='train')
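The last frames of the trace suggest one plausible mechanism: every sample pulled from the source calls `draw()`, which picks a split with `random.choices` and decrements that split's weight, so `total_length` behaves like a budget of draws. If the CSV holds more than `total_length` rows (an assumption, since the number of rows in the file is not given), the budget runs out before the source does, and the next draw sees all-zero weights. `SimplifiedSplitter` and `collect` are hypothetical names for a simplified model of that bookkeeping, not torchdata's actual code; in particular, scaling each weight to `weight * total_length / sum(weights)` is assumed:

```python
import random
from typing import Dict, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


class SimplifiedSplitter:
    """Hypothetical model of the draw()/weights bookkeeping seen in the trace.

    Not torchdata's implementation: each split gets a budget of
    weight * total_length / sum(weights) draws, and every draw spends one unit.
    """

    def __init__(self, total_length: int, weights: Dict[str, float], seed: int) -> None:
        self.keys: List[str] = list(weights)
        weight_sum = sum(weights.values())
        self.budget = [weights[k] * total_length / weight_sum for k in self.keys]
        self.seed = seed

    def iterate(self, source: Iterable[T], target: str) -> Iterator[T]:
        rng = random.Random(self.seed)
        remaining = list(self.budget)  # fresh copy per pass, like reset()
        for sample in source:
            # Raises ValueError once every remaining weight has reached zero.
            key = rng.choices(self.keys, remaining)[0]
            remaining[self.keys.index(key)] -= 1
            if key == target:
                yield sample


def collect(splitter: SimplifiedSplitter, source: Iterable[int]) -> List[int]:
    """Gather 'train' rows until the source ends or a ValueError is raised."""
    seen: List[int] = []
    try:
        for row in splitter.iterate(source, "train"):
            seen.append(row)
    except ValueError as err:
        print(f"{len(seen)} rows, then: {err}")
    return seen


splitter = SimplifiedSplitter(10, {"train": 0.8, "test": 0.2}, seed=42)
collect(splitter, range(25))  # prints: 8 rows, then: Total of weights must be greater than zero
print(len(collect(splitter, range(10))))  # 8
```

With a 25-row source the model yields 8 rows and then fails, matching the 8-then-error pattern in the report; with a source exactly `total_length` rows long it ends cleanly. If the model is right, the same would apply to the real datapipe whenever the CSV has more rows than `NumRows`.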