daniilrobnikov/vits2

DistributedBucketSampler division by zero problem

y1guo opened this issue · 1 comment

y1guo commented

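In the DistributedBucketSampler class, the first bucket can end up with no elements. The loop in _create_buckets that removes empty buckets stops at index 1, so an empty bucket at index 0 is never removed, which later causes a division-by-zero error.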
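Below is a minimal, torch-free sketch of the failure. It assumes that the elided `__iter__` pads each bucket with an expression like `rem // len_bucket`, which the upstream VITS sampler appears to do. `drop_empty_buckets_original` and `pad_bucket` are hypothetical helpers written only for illustration. An empty bucket gets a padded size of `0 + (total - 0 % total) % total = 0`, so `rem` is 0 and the padding step ends up evaluating `0 // 0`.

```python
def drop_empty_buckets_original(buckets, boundaries):
    """Hypothetical copy of the original cleanup loop (stop value 0)."""
    for i in range(len(buckets) - 1, 0, -1):  # index 0 is never visited
        if len(buckets[i]) == 0:
            buckets.pop(i)
            boundaries.pop(i + 1)
    return buckets, boundaries


def pad_bucket(ids_bucket, num_samples_bucket):
    """Hypothetical padding step: repeat ids until the bucket has num_samples_bucket entries."""
    len_bucket = len(ids_bucket)
    rem = num_samples_bucket - len_bucket
    # rem // len_bucket raises ZeroDivisionError when the bucket is empty.
    return ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: rem % len_bucket]


# Bucket 0 received no samples; buckets 1 and 2 are non-empty.
buckets, boundaries = drop_empty_buckets_original([[], [0, 1], [2]], [32, 100, 200, 300])
print(buckets)  # [[], [0, 1], [2]]  (the empty first bucket survives)
print(pad_bucket([2], 2))  # [2, 2]

try:
    pad_bucket(buckets[0], 0)  # the empty bucket's padded size is 0
except ZeroDivisionError as err:
    print("ZeroDivisionError:", err)
```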

```python
import torch


class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):

    # ...

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # original: for i in range(len(buckets) - 1, 0, -1):
        # fix: stop at -1 instead of 0 so that an empty bucket at index 0 is also popped
        for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        # pad each bucket up to a multiple of the global batch size (num_replicas * batch_size)
        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    # ...
```
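The patched loop and the padding arithmetic can be checked in isolation. What follows is a minimal sketch. `create_buckets_fixed` is a hypothetical torch-free helper, and `bisect.bisect_left` stands in for the elided `_bisect`. The sketch assumes that bucket k covers lengths in (boundaries[k], boundaries[k + 1]]. Each bucket is rounded up to the next multiple of `num_replicas * batch_size`. For example, 5 items with 2 replicas and batch size 2 give `rem = (4 - 5 % 4) % 4 = 3`, so the bucket is padded to 8. The outer `% total_batch_size` makes `rem` 0 when the size is already a multiple.

```python
import bisect


def create_buckets_fixed(lengths, boundaries, num_replicas, batch_size):
    """Hypothetical torch-free mirror of the patched _create_buckets."""
    boundaries = list(boundaries)  # copy; the real method mutates self.boundaries
    buckets = [[] for _ in range(len(boundaries) - 1)]
    for i, length in enumerate(lengths):
        # Bucket k holds lengths in (boundaries[k], boundaries[k + 1]]; others are dropped.
        if boundaries[0] < length <= boundaries[-1]:
            buckets[bisect.bisect_left(boundaries, length) - 1].append(i)

    # Patched loop: the stop value -1 means index 0 is visited too.
    for i in range(len(buckets) - 1, -1, -1):
        if len(buckets[i]) == 0:
            buckets.pop(i)
            boundaries.pop(i + 1)

    # Round each bucket up to a multiple of the global batch size.
    total_batch_size = num_replicas * batch_size
    num_samples_per_bucket = []
    for bucket in buckets:
        rem = (total_batch_size - (len(bucket) % total_batch_size)) % total_batch_size
        num_samples_per_bucket.append(len(bucket) + rem)
    return buckets, num_samples_per_bucket, boundaries


buckets, num_samples, boundaries = create_buckets_fixed(
    [150, 180, 250], [32, 100, 200, 300], num_replicas=1, batch_size=2
)
print(buckets)      # [[0, 1], [2]]
print(num_samples)  # [2, 2]
print(boundaries)   # [32, 200, 300]
```

After the fix, the empty bucket (32, 100] is removed together with its upper boundary. Every remaining bucket is non-empty, so no later step divides by zero.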

Glad to discuss :)

Thanks for pointing that out; I will look into it.

Actually, I had an idea to change the sampling technique for a future release.