archinetai/audio-diffusion-pytorch

Exploding loss


The loss suddenly increases from <0.1 to billions over one or two epochs.

I'm training an AudioDiffusionModel, and I've had this happen with both the default diffusion_type='v' and with diffusion_type='vk'. It also happens both with and without gradient clipping. It's happened with several datasets and different batch sizes (the output below is from a particularly small dataset with a large batch size).

It seems to happen more often the closer the loss gets to 0.

Output:

1328 Loss : 0.0562
100% 6/6 [00:01<00:00,  3.93it/s]
1329 Loss : 0.0517
100% 6/6 [00:01<00:00,  3.95it/s]
1330 Loss : 0.0500
100% 6/6 [00:01<00:00,  3.95it/s]
1331 Loss : 0.0374
100% 6/6 [00:01<00:00,  3.93it/s]
1332 Loss : 0.0519
100% 6/6 [00:01<00:00,  3.69it/s]
1333 Loss : 0.0557
100% 6/6 [00:01<00:00,  3.47it/s]
1334 Loss : 0.0499
100% 6/6 [00:01<00:00,  3.33it/s]
1335 Loss : 0.0482
100% 6/6 [00:01<00:00,  3.74it/s]
1336 Loss : 1.4608
100% 6/6 [00:01<00:00,  3.89it/s]
1337 Loss : 35551447.3009
100% 6/6 [00:01<00:00,  3.91it/s]
1338 Loss : 17436217794.0833
100% 6/6 [00:01<00:00,  3.86it/s]
1339 Loss : 15120838197.3333
100% 6/6 [00:01<00:00,  3.88it/s]
1340 Loss : 1137136360.0000
100% 6/6 [00:01<00:00,  3.83it/s]
1341 Loss : 184102040.6667
100% 6/6 [00:01<00:00,  3.80it/s]
1342 Loss : 24171988.5000
100% 6/6 [00:01<00:00,  3.85it/s]
1343 Loss : 100907.1549
100% 6/6 [00:01<00:00,  3.80it/s]
1344 Loss : 10494.4541
100% 6/6 [00:01<00:00,  3.83it/s]
1345 Loss : 989.2273

The model:

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

from audio_diffusion_pytorch import AudioDiffusionModel, VKDistribution

class DiffModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AudioDiffusionModel(
            in_channels=1,
            diffusion_type='vk',
            diffusion_sigma_distribution=VKDistribution(),
        )
        # AdamW with PyTorch's default hyperparameters (lr=1e-3, ...)
        self.optimizer = torch.optim.AdamW(self.model.parameters())

    def train(self, x):  # note: shadows nn.Module.train()
        self.optimizer.zero_grad()

        loss = self.model(x)
        loss.backward()

        # Clip gradients to unit norm before stepping
        clip_grad_norm_(self.model.parameters(), 1.)

        self.optimizer.step()

        return loss.item()
    ...

Training:

for epoch in range(load_epoch + 1, MAX_EPOCHS):
    acc_loss = 0
    for x in tqdm(dataloader):
        x = x.to(device)
        acc_loss += model.train(x)
    # epoch_steps: number of batches per epoch (presumably len(dataloader))
    loss = acc_loss / epoch_steps
    print(f'{epoch} Loss : {loss:.4f}')
    ...
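
A minimal debugging sketch for narrowing down which batch triggers the spike, assuming the DiffModel and dataloader above (the 10.0 threshold and the dump file name are arbitrary illustrations, not recommended values):

# Sketch: log per-batch loss to find which batch triggers the spike.
for step, x in enumerate(dataloader):
    x = x.to(device)
    batch_loss = model.train(x)
    if batch_loss > 10.0:  # arbitrary spike threshold for illustration
        print(f'step {step}: loss {batch_loss:.4f}')
        torch.save(x.cpu(), f'bad_batch_{step}.pt')  # keep the batch for inspection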

I didn't thoroughly test vk diffusion (usually I go with v), but I've also never had exploding-loss problems. Check that your dataset is evenly distributed: if a batch suddenly contains multiple silent samples, that can mess up the model. For that, you might want to use the WAVDataset in https://github.com/archinetai/audio-data-pytorch with check_silence set to True, or do some similar checks.
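
For a quick manual check along those lines, one option is to flag near-silent clips by RMS before training. A minimal sketch, assuming dataset yields waveform tensors of shape [channels, samples]; the 1e-4 threshold is an arbitrary cutoff:

import torch

def is_silent(x: torch.Tensor, threshold: float = 1e-4) -> bool:
    # Flag a clip whose RMS falls below the threshold.
    # 1e-4 is an arbitrary cutoff; tune it for your data.
    return x.pow(2).mean().sqrt().item() < threshold

silent = [i for i, x in enumerate(dataset) if is_silent(x)]
print(f'found {len(silent)} near-silent samples: {silent}')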

I am actually using WAVDataset with check_silence set to True (which is the default). Also, the datasets I'm using are taken from one-shot sound packs, and another was a set of wavetables, so that's apparently not the problem...
Thanks for the help btw! :)

Turns out this was a good old "too high a learning rate" problem...
I was using the default optimizer settings for lr, betas, eps and weight_decay.
Using the base configuration from audio-diffusion-pytorch-trainer solved this issue:

self.optimizer = torch.optim.AdamW(
    params=list(self.model.parameters()),
    lr=1e-4,
    betas=(0.95, 0.999),
    eps=1e-6,
    weight_decay=1e-3,
)
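
For comparison, PyTorch's AdamW defaults are lr=1e-3, betas=(0.9, 0.999), eps=1e-8 and weight_decay=1e-2, so the key change here is the 10x lower learning rate; the betas, eps and weight_decay values also differ slightly.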