Spectrogram-based diffusion model
Closed this issue · 2 comments
Thanks for your contribution to this repository! I wonder whether this repository can be used to build a diffusion model that operates on spectrograms instead of waveforms. While implementing this, I found that UNetV0 has a dim=2 option that enables a 2D CNN for spectrograms. However, some of the UNetV0 hyperparameters seem inconsistent and lead to an error. It's hard for me to debug because the model relies heavily on a-unet. More context below.
Suppose we have a spectrogram whose shape is [1, 1, 256, 512]  # [B, C, T, F].
Here's the model architecture I used:
return DiffusionModel(
    net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
    dim=2,  # For spectrograms we use a 2D CNN
    in_channels=1,  # U-Net: number of input/output channels
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],  # U-Net: channels at each layer
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4],  # U-Net: number of repeating items at each layer
    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
    attention_heads=8,  # U-Net: number of attention heads per attention item
    attention_features=64,  # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion,  # The diffusion method used
    sampler_t=VSampler,  # The diffusion sampler used
    embedding_features=512,  # U-Net: embedding features
    cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1],  # U-Net: cross-attention enabled/disabled at each layer
)
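As a quick check before running this configuration (an editorial aside, not part of the original post): the factors above multiply to a total downsampling larger than the 256-sized axis, assuming one floor division per level.

```python
import math

factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]  # factors from the configuration above
total = math.prod(factors)

print(total)         # 2048: cumulative downsampling across all levels
print(256 // total)  # 0: the T axis of [1, 1, 256, 512] collapses entirely
```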
Here's the output error:
Traceback (most recent call last):
  File "/data/tinglok/texture/ldm.py", line 164, in <module>
    main()
  File "/data/tinglok/texture/ldm.py", line 105, in main
    loss = model(audio, embedding=cond_embed)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/audio_diffusion_pytorch/models.py", line 40, in forward
    return self.diffusion(*args, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/audio_diffusion_pytorch/diffusion.py", line 93, in forward
    v_pred = self.net(x_noisy, sigmas, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 63, in forward
    return forward_fn(*args, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 594, in forward
    return net(x, features=features, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 431, in forward
    return self.net(x, features, embedding, channels)  # type: ignore
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  [the apex.py:382 / blocks.py:77 pair of frames repeats once per nested U-Net level, interleaved with torch's module.py:1194 _call_impl frames]
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 92, in forward
    return self.block(*args_fn(*args), **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (2 x 2). Kernel size can't be greater than actual input size
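To see where the (1 x 1) input arises, one can trace the size of the 256-sample axis through the downsampling levels. This is a sketch: `level_sizes` is a hypothetical helper written for illustration, not part of a-unet, and it assumes simple floor division per level.

```python
def level_sizes(size, factors):
    """Size of one spatial axis after each downsampling level (floor division)."""
    sizes = []
    for f in factors:
        size //= f
        sizes.append(size)
    return sizes

# T axis of the [1, 1, 256, 512] spectrogram against the original factors
print(level_sizes(256, [1, 4, 4, 4, 2, 2, 2, 2, 2]))
# [256, 64, 16, 4, 2, 1, 0, 0, 0] -- once the axis reaches 1, a 2x2 kernel no longer fits
```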
Can we resolve this error by disabling some of the downsampling in UNetV0?
Yeah, I think the problem is the downsampling factors [1, 4, 4, 4, 2, 2, 2, 2, 2]: they multiply to 2048, so a dimension of size 256 would be downsampled to 256 // 2048, which is less than 1. For spectrograms I'd use only 2x downsampling and a much shallower net, something like:
return DiffusionModel(
    net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
    dim=2,  # For spectrograms we use a 2D CNN
    in_channels=1,  # U-Net: number of input/output channels
    channels=[32, 64, 128, 256],  # U-Net: channels at each layer
    factors=[1, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
    items=[2, 2, 2, 4],  # U-Net: number of repeating items at each layer
    attentions=[0, 0, 0, 1],  # U-Net: attention enabled/disabled at each layer
    attention_heads=8,  # U-Net: number of attention heads per attention item
    attention_features=64,  # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion,  # The diffusion method used
    sampler_t=VSampler,  # The diffusion sampler used
    embedding_features=512,  # U-Net: embedding features
    cross_attentions=[0, 0, 0, 1],  # U-Net: cross-attention enabled/disabled at each layer
)
Note that you would have to try different variations of channels, factors, items, attentions, and cross_attentions to see what works best.
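Whatever variation you pick, the factors can be sanity-checked before training. As an editorial sketch (assuming floor division per level, with the 256 x 512 spectrogram from the original post):

```python
import math

T, F = 256, 512             # spectrogram axes [B, C, T, F] from the original post
factors = [1, 2, 2, 2]      # suggested shallower configuration
total = math.prod(factors)

print(total)                # 8: cumulative downsampling
print(T // total, F // total)  # 32 64: both axes stay well above the 2x2 kernel
# Rule of thumb: math.prod(factors) should evenly divide, and be much smaller
# than, the smallest spatial dimension of the input.
```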
Thanks for your reply! I'll check it out.