DistributedBucketSampler division by zero problem
y1guo opened this issue · 1 comment
y1guo commented
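In the DistributedBucketSampler class, the first bucket can end up with no elements. The cleanup loop in _create_buckets iterates with range(len(buckets) - 1, 0, -1), which stops before index 0, so an empty first bucket is never removed. This results in a division-by-zero error. Changing the stop index to -1 fixes it: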
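A minimal sketch of the failure mode follows. The helper names are hypothetical: cleanup_original mirrors the original cleanup loop (the boundary bookkeeping is omitted), and pad_bucket is an assumed padding step that repeats a bucket's indices by dividing by the bucket length. The issue does not show where the division actually happens. An empty bucket gets a padded size of 0 from (total_batch_size - (0 % total_batch_size)) % total_batch_size = 0, so such a padding step ends up computing 0 // 0.

```python
def cleanup_original(buckets):
    """Original cleanup loop: stop index 0 means bucket 0 is never checked."""
    for i in range(len(buckets) - 1, 0, -1):
        if len(buckets[i]) == 0:
            buckets.pop(i)
    return buckets


def pad_bucket(ids, num_samples):
    """Hypothetical padding step: repeat ids until there are num_samples entries.

    Dividing by len(ids) raises ZeroDivisionError when the bucket is empty.
    """
    rem = num_samples - len(ids)
    return ids + ids * (rem // len(ids)) + ids[: rem % len(ids)]


buckets = cleanup_original([[], [3, 7], []])
print(buckets)  # [[], [3, 7]]: the empty first bucket survives

try:
    pad_bucket(buckets[0], 0)  # padded size of an empty bucket is 0
except ZeroDivisionError:
    print("ZeroDivisionError raised for the empty first bucket")
```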
import torch


class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    # ...
    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)
        # original: for i in range(len(buckets) - 1, 0, -1):
        for i in range(len(buckets) - 1, -1, -1):  # stop index changed from 0 to -1 so bucket 0 is checked and can be popped
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)
        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            # pad each bucket up to the next multiple of the global batch size
            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket
    # ...
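For experimenting outside the class, here is a standalone sketch of the patched logic. It assumes bucket i holds lengths in (boundaries[i], boundaries[i + 1]] and uses bisect.bisect_left as a stand-in for self._bisect; the function name create_buckets and its parameters are hypothetical. Unlike the method, it copies boundaries instead of mutating an attribute and returns the pruned list alongside the buckets.

```python
import bisect


def create_buckets(lengths, boundaries, num_replicas, batch_size):
    """Hypothetical standalone version of the patched _create_buckets.

    Assumes bucket i holds lengths in (boundaries[i], boundaries[i + 1]];
    lengths outside every bucket are skipped.
    """
    boundaries = list(boundaries)  # copy so the caller's list is not mutated
    buckets = [[] for _ in range(len(boundaries) - 1)]
    for i, length in enumerate(lengths):
        # bisect_left - 1 gives the i with boundaries[i] < length <= boundaries[i + 1]
        idx_bucket = bisect.bisect_left(boundaries, length) - 1
        if 0 <= idx_bucket < len(buckets):
            buckets[idx_bucket].append(i)

    # Iterate down to index 0 (stop=-1) so an empty first bucket is removed too.
    for i in range(len(buckets) - 1, -1, -1):
        if not buckets[i]:
            buckets.pop(i)
            boundaries.pop(i + 1)

    # Pad each bucket up to a multiple of the global batch size.
    total_batch_size = num_replicas * batch_size
    num_samples_per_bucket = []
    for bucket in buckets:
        rem = (total_batch_size - len(bucket) % total_batch_size) % total_batch_size
        num_samples_per_bucket.append(len(bucket) + rem)
    return buckets, boundaries, num_samples_per_bucket


buckets, boundaries, num_samples = create_buckets(
    lengths=[15, 12, 25, 18, 11],
    boundaries=[0, 10, 20, 30],
    num_replicas=2,
    batch_size=2,
)
print(buckets)      # [[0, 1, 3, 4], [2]]
print(boundaries)   # [0, 20, 30]
print(num_samples)  # [4, 4]
```

Popping boundaries[i + 1] when bucket i is empty merges its empty range into the next bucket (or drops it, for the last bucket), so bucket indices and boundaries stay aligned even when i is 0.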
Glad to discuss :)
daniilrobnikov commented
Thanks for pointing that out; I will look into it.
Actually, I have an idea to change the sampling technique in a future release.