archinetai/audio-diffusion-pytorch

Spectrogram-based diffusion model

Closed this issue · 2 comments

Thanks for your contribution to this repository! I wonder if we can use it to develop a diffusion model based on spectrograms instead of waveforms. While implementing this, I discovered that UNetV0 has a dim=2 option that allows 2D convolutions to be used on spectrograms. However, some of the UNetV0 hyperparameters seem to be inconsistent with my input and lead to an error. It's a bit hard for me to debug since the model relies heavily on a-unet. I'll give more context below.

Suppose we have a spectrogram whose shape is [1, 1, 256, 512] # [B, C, T, F].

Here's the model architecture I used:

return DiffusionModel(
    net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
    dim=2, # for spectrogram we use 2D-CNN
    in_channels=1, # U-Net: number of input/output (audio) channels
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
    attention_heads=8, # U-Net: number of attention heads per attention item
    attention_features=64, # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion, # The diffusion method used
    sampler_t=VSampler, # The diffusion sampler used
    embedding_features=512, # U-Net: embedding features
    cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
)

Here's the output error:

Traceback (most recent call last):
  File "/data/tinglok/texture/ldm.py", line 164, in <module>
    main()
  File "/data/tinglok/texture/ldm.py", line 105, in main
    loss = model(audio, embedding=cond_embed)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/audio_diffusion_pytorch/models.py", line 40, in forward
    return self.diffusion(*args, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/audio_diffusion_pytorch/diffusion.py", line 93, in forward
    v_pred = self.net(x_noisy, sigmas, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 63, in forward
    return forward_fn(*args, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 594, in forward
    return net(x, features=features, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 431, in forward
    return self.net(x, features, embedding, channels)  # type: ignore
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/apex.py", line 382, in forward
    x = self.block(x, features, embedding, channels)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 77, in forward
    x = block(x, *args)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/a_unet/blocks.py", line 92, in forward
    return self.block(*args_fn(*args), **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/tinglok/miniconda3/envs/texture/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (2 x 2). Kernel size can't be greater than actual input size

Can we resolve this error by disabling some downsampling of UNetV0?

Yeah, I think the problem is that the downsampling factors [1, 4, 4, 4, 2, 2, 2, 2, 2] multiply to 2048, so a dimension of size 256 would be downsampled to 256/2048, which is less than 1. For spectrograms I'd use only 2x downsampling and a much shallower net, something like:

net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
dim=2, # for spectrogram we use 2D-CNN
in_channels=1, # U-Net: number of input/output (audio) channels
channels=[32, 64, 128, 256], # U-Net: channels at each layer
factors=[1, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
items=[2, 2, 2, 4], # U-Net: number of repeating items at each layer
attentions=[0, 0, 0, 1], # U-Net: attention enabled/disabled at each layer
attention_heads=8, # U-Net: number of attention heads per attention item
attention_features=64, # U-Net: number of attention features per attention item
diffusion_t=VDiffusion, # The diffusion method used
sampler_t=VSampler, # The diffusion sampler used
embedding_features=512, # U-Net: embedding features
cross_attentions=[0, 0, 0, 1], # U-Net: cross-attention enabled/disabled at each layer 

Note that you would have to try different variations of channels, factors, items, attentions, and cross_attentions to see what works best.
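
A quick way to sanity-check a factors list before building the model is sketched below. This is plain Python and not part of audio-diffusion-pytorch or a-unet: check_factors is a hypothetical helper, and it assumes every factor is applied to every spatial dimension with floor division, which only approximates what the U-Net does internally.

```python
from math import prod


def check_factors(shape, factors, min_size=1):
    """Estimate the spatial size left at the bottom of the U-Net.

    shape:    spatial dimensions of the input, e.g. (T, F) for a spectrogram
    factors:  per-layer downsampling factors
    min_size: smallest size each dimension is allowed to reach

    Returns (total downsampling, final shape, whether the final shape is usable).
    Assumption: each factor divides every spatial dimension (floor division).
    """
    total = prod(factors)
    final = tuple(size // total for size in shape)
    ok = all(size >= min_size for size in final)
    return total, final, ok


spectrogram = (256, 512)  # [T, F] from the example in this thread

# Original configuration: total downsampling of 2048 collapses both dimensions.
print(check_factors(spectrogram, [1, 4, 4, 4, 2, 2, 2, 2, 2]))

# Shallower configuration: total downsampling of 8 leaves a 32 x 64 map.
print(check_factors(spectrogram, [1, 2, 2, 2]))
```

If the last element of the result is False, some dimension is estimated to shrink below min_size, which is the situation behind the conv2d error above.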

Thanks for your reply! I'll check it out.