Exploding loss
The loss suddenly increases from below 0.1 to billions over one or two epochs.
I'm training an AudioDiffusionModel, and I've had this happen both with the default `diffusion_type='v'` and with `diffusion_type='vk'`. It also happens both with and without gradient clipping. It has happened with several datasets and different batch sizes (the output below is from a particularly small dataset with a large batch size).
It seems to happen more often the closer the loss gets to 0.
Output:
```
1328 Loss : 0.0562
100% 6/6 [00:01<00:00, 3.93it/s]
1329 Loss : 0.0517
100% 6/6 [00:01<00:00, 3.95it/s]
1330 Loss : 0.0500
100% 6/6 [00:01<00:00, 3.95it/s]
1331 Loss : 0.0374
100% 6/6 [00:01<00:00, 3.93it/s]
1332 Loss : 0.0519
100% 6/6 [00:01<00:00, 3.69it/s]
1333 Loss : 0.0557
100% 6/6 [00:01<00:00, 3.47it/s]
1334 Loss : 0.0499
100% 6/6 [00:01<00:00, 3.33it/s]
1335 Loss : 0.0482
100% 6/6 [00:01<00:00, 3.74it/s]
1336 Loss : 1.4608
100% 6/6 [00:01<00:00, 3.89it/s]
1337 Loss : 35551447.3009
100% 6/6 [00:01<00:00, 3.91it/s]
1338 Loss : 17436217794.0833
100% 6/6 [00:01<00:00, 3.86it/s]
1339 Loss : 15120838197.3333
100% 6/6 [00:01<00:00, 3.88it/s]
1340 Loss : 1137136360.0000
100% 6/6 [00:01<00:00, 3.83it/s]
1341 Loss : 184102040.6667
100% 6/6 [00:01<00:00, 3.80it/s]
1342 Loss : 24171988.5000
100% 6/6 [00:01<00:00, 3.85it/s]
1343 Loss : 100907.1549
100% 6/6 [00:01<00:00, 3.80it/s]
1344 Loss : 10494.4541
100% 6/6 [00:01<00:00, 3.83it/s]
1345 Loss : 989.2273
```
The model:
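```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from audio_diffusion_pytorch import AudioDiffusionModel, VKDistribution


class DiffModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AudioDiffusionModel(in_channels=1, diffusion_type='vk', diffusion_sigma_distribution=VKDistribution())
        # Only the parameters are passed, so AdamW's defaults apply:
        # lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2.
        self.optimizer = torch.optim.AdamW(list(self.model.parameters()))

    # Note: this overrides nn.Module.train(mode), so model.train()/model.eval()
    # no longer switch modes (model.eval() calls self.train(False), i.e. this method).
    def train(self, x):
        self.optimizer.zero_grad()
        loss = self.model(x)  # the model's forward pass returns the training loss
        loss.backward()
        clip_grad_norm_(self.model.parameters(), 1.)  # rescale gradients to a global L2 norm of at most 1
        self.optimizer.step()
        return loss.item()
...
```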
Training:
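```python
from tqdm import tqdm

# model, dataloader, device, load_epoch, MAX_EPOCHS and epoch_steps are defined elsewhere (not shown).
for epoch in range(load_epoch + 1, MAX_EPOCHS):
    acc_loss = 0
    for x in tqdm(dataloader):
        x = x.to(device)
        acc_loss += model.train(x)  # one optimization step; returns the batch loss
    loss = acc_loss / epoch_steps  # average per-step loss, assuming epoch_steps is the number of batches
    print(f'{epoch} Loss : {loss:.4f}')
...
```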
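The post notes that the blow-up happened with and without `clip_grad_norm_`. One plausible reason, not confirmed in this thread: Adam-style optimizers divide each gradient by a running estimate of its RMS, so the size of an update is set mostly by `lr` rather than by the gradient's magnitude, and rescaling the gradients does little to shrink it. The sketch below applies a single `torch.optim.AdamW` step to one zero-initialised parameter with gradients of very different sizes; the helper `first_step_change` is hypothetical, written only for this illustration.

```python
import torch


def first_step_change(grad_value: float, lr: float = 1e-3) -> float:
    """Hypothetical helper: apply one AdamW step to a zero-initialised scalar
    parameter with the given gradient and return how far the parameter moved."""
    p = torch.zeros(1, requires_grad=True)
    # Other hyperparameters stay at PyTorch's defaults; weight decay has no
    # effect here because the parameter starts at zero.
    opt = torch.optim.AdamW([p], lr=lr)
    p.grad = torch.tensor([grad_value])  # set the gradient directly, no backward pass needed
    opt.step()
    return abs(p.item())


if __name__ == "__main__":
    # Gradients spanning six orders of magnitude all give a step of roughly lr.
    for g in (1e-3, 1.0, 1e3):
        print(f"grad={g:g}  step={first_step_change(g):.6f}")
```

All three steps come out at about 0.001, i.e. `lr`: after bias correction the first update is `lr * g / (|g| + eps)`. Later steps use running averages, so the invariance is exact only for a rescale that stays constant over time, and per-step clipping can still have some effect. Lowering `lr`, by contrast, shrinks every update directly.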
I didn't thoroughly test vk diffusion (I usually go with v), but I have also never had exploding-loss problems. Check that your dataset is evenly distributed, i.e. that a batch doesn't suddenly contain several silent samples, which can mess up the model. For that, you might want to use the WAVDataset in https://github.com/archinetai/audio-data-pytorch with `check_silence` set to `True`, or do some similar checks.
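A minimal version of such a check might flag clips whose RMS level falls below a threshold before they reach the model. This is a hypothetical sketch, not how WAVDataset's `check_silence` is implemented; the helper name `silent_mask`, the -60 dBFS threshold and the `(batch, channels, time)` layout are all assumptions.

```python
import torch


def silent_mask(batch: torch.Tensor, threshold_db: float = -60.0) -> torch.Tensor:
    """Hypothetical helper: return a boolean mask marking samples whose RMS level
    is below threshold_db. Assumes `batch` is shaped (batch, channels, time)
    with values roughly in [-1, 1]."""
    # Root-mean-square amplitude of each sample over channels and time.
    rms = batch.pow(2).mean(dim=(1, 2)).sqrt()
    # Convert to decibels relative to full scale; the small constant avoids log10(0).
    level_db = 20.0 * torch.log10(rms + 1e-12)
    return level_db < threshold_db


if __name__ == "__main__":
    torch.manual_seed(0)
    batch = torch.randn(4, 1, 16000) * 0.1  # four non-silent mono clips (about -20 dBFS)
    batch[2].zero_()                         # make one clip fully silent
    print(silent_mask(batch).tolist())       # [False, False, True, False]
```

Dropping the flagged samples, e.g. with `batch[~silent_mask(batch)]`, before the forward pass is one option; the threshold is an arbitrary choice here and would need tuning for real material.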
I am actually using WAVDataset with `check_silence` set to `True`, which is the default. Also, the datasets I'm using are taken from one-shot sound packs, and another one was a set of wavetables, so that apparently isn't the problem...
Thanks for the help btw! :)
Turns out this was a good old "too high a learning rate" problem...
I was using the default optimizer settings for `lr`, `betas`, `eps` and `weight_decay`.
Using the base configuration from audio-diffusion-pytorch-trainer solved this issue:
```python
self.optimizer = torch.optim.AdamW(
    params=list(self.model.parameters()),
    lr=1e-4,
    betas=(0.95, 0.999),
    eps=1e-6,
    weight_decay=1e-3)
```
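For comparison, this snippet prints the defaults PyTorch's `AdamW` uses when, as in the original `DiffModel`, only the parameters are passed, next to the trainer configuration quoted above. The default learning rate is 1e-3, ten times the trainer's 1e-4. The thread attributes the fix to the learning rate, although `betas`, `eps` and `weight_decay` changed as well. The dummy parameter exists only so the optimizer can be constructed.

```python
import torch

# A throwaway parameter, only needed so the optimizer can be constructed.
dummy = [torch.nn.Parameter(torch.zeros(1))]

# Passing nothing but the parameters, as the original DiffModel did, uses these defaults.
default_opt = torch.optim.AdamW(dummy)

# Base configuration from audio-diffusion-pytorch-trainer, as quoted in this thread.
trainer_cfg = {"lr": 1e-4, "betas": (0.95, 0.999), "eps": 1e-6, "weight_decay": 1e-3}

for key, trainer_value in trainer_cfg.items():
    print(f"{key:>12}: default={default_opt.defaults[key]}  trainer={trainer_value}")
```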