
Fourier Dual Diffusion

Primary language: Python. License: MIT.

Dual Diffusion

Dual Diffusion is a generative diffusion model for music. This model and the code in this repository are still a work in progress.

I'm currently using music from SNES games as my dataset during development. The dataset consists of ~20,000 samples, each between 1 and 3 minutes long. I use the game each track comes from as a class label for conditioning, which means you can choose a game (or a weighted combination of multiple games) to generate new music with the appropriate instruments and style. Each class / game has anywhere from ~5 to ~50 examples, and all generated samples that combine more than one game are "zero-shot".

You can hear some samples generated by the model in various stages of training / development here.

I started this project in August/2023 with three goals:

  • Familiarize myself with every component of modern diffusion models in both inference and training
  • Train a model from scratch that is able to generate music that I would actually want to listen to
  • Do the above using only a single consumer GPU (4090)

The model has changed substantially over the course of development in the last 12 months.

  • Initially (August/2023) the diffusion model was unconditional and worked directly on raw audio.

    • Due to memory and performance constraints, this limited me to ~16 seconds of mono audio at 8 kHz.
    • I experimented with 1d and 2d formats with various preprocessing steps. I found 2d formats were able to generalize better with a small dataset and were more compute and parameter efficient than 1d formats.
    • For 2d formats I found using separable attention (merging the rows / columns with the batch dimension alternately) could make using attention in a high dimensionality model practical without sacrificing too much quality.
    • I found that attention was an absolute requirement for the model to understand the perceptual ~log-frequency scale of music when the format uses a linear-frequency scale, especially when combined with positional embeddings for the frequency axis.
    • I found that v-prediction with a cosine-based schedule worked significantly better than any alternatives (for raw audio / non-latent diffusion).
  • In December/2023 I began training variational auto-encoder models to try a latent diffusion model.

    • After moving to latent diffusion, I was able to begin training the diffusion model on 45-second crops of 32 kHz stereo audio.
    • I found that point-wise loss works very poorly in terms of reconstruction quality, and multi-scale spectral loss was a much better option. I found that multi-scale spectral power density loss works considerably better for music and learning pitch if you add an appropriately weighted loss term for phase.
    • I found that although it was possible for a VAE that includes phase to have good reconstruction quality and a compact latent space, the latent space had high entropy and was not easily interpretable by the latent diffusion model. This resulted in a model that could generate good quality sounds but had poor musicality.
    • I found that the latent diffusion model performs substantially better with input latents that have uniform variance; there is no need for the log-variance to be a learnable parameter. Instead, I predefine a uniform target SNR for the latent distribution.
    • I found that the latent diffusion model's performance is substantially improved by a lower target SNR / higher log-variance in the latent distribution, provided the latent distribution mode is used as the training target. A lower target SNR allows the latent diffusion model to exclude low noise levels during both training and inference; the residual noise after sampling can be designed to match the expected variance in the latent distribution.
  • In March/2024 I started using mel-scale spectrogram based formats, excluding phase information.

    • I found that it was possible to considerably improve the FGLA phase reconstruction by tuning the window and spectrogram parameters, as well as modifying the algorithm to anneal phases for stereo signals in a more coherent way. I settled on a set of parameters that resulted in a spectrogram dimensionality that is the same as the critically sampled raw audio without sacrificing too much perceptual quality.
    • I found that multi-scale spectral loss works well for 2d spectrogram / image data; the resulting quality is somewhere between point-wise loss and adversarial loss.
  • In April/2024 I replaced the diffusion model and VAE unets (previously based on the unconditional unet in diffusers) with the improved edm2 unet.

    • I made several improvements to the edm2 unet:
      • Replaced einsum attention with torch sdp attention
      • Changed the Fourier embedding frequencies / phases to give a smoother inner product space
      • Added class label dropout in a way that preserves magnitude on expectation
      • Replaced the weight normalization in the forward method of the mpconv module with normalization that is applied only when the weights are updated, for improved performance and lower memory consumption
      • Added a correction when using dropout inside blocks to preserve magnitude on expectation during inference
      • Replaced the up/downsample with equivalent torch built-ins for improved performance
      • Merged some of the pre-conditioning code into the unet itself.
    • I started using torch dynamo / model compilation and added the appropriate compiler hints for maximum performance.
    • I started using class label-based conditioning and implemented classifier free guidance for a major improvement in quality and control.
  • In May/2024 I adopted the edm/ddim noise schedule, sampling algorithm, and learning rate schedule.

    • I found that the log-normal sigma sampling in training could be improved by using the per-sigma estimated error variance to calculate a sigma sampling pdf that explicitly focuses on the noise levels that the model is most capable of making progress on.
    • I found that using stratified sampling to distribute sigmas as evenly as possible within each mini-batch could mitigate problems with smaller mini-batches.
    • I began pre-encoding the latents for my dataset before diffusion model training for increased performance and reduced memory consumption. I found that pre-encoding the latents before the random crop can hurt generated sample quality because it removes the variations created by sub-latent-pixel offsets, so I added pre-encoded latent variations for those offsets.
    • I began training with EMA weights of multiple lengths.
  • In June-July/2024 I experimented with the model architecture and added noise in sampling.

    • I found that using low-rank linear layers in the model resnet blocks could significantly increase quality without drastically increasing the number of parameters; specifically, projecting to a higher number of dimensions inside the resnet block where the non-linearity is applied and then projecting back down.
    • I found that adding conditioned modulation to the self-attention qkv linear layers significantly increased conditioning quality.
    • I trained another 1d model using all the new architecture improvements and confirmed my earlier finding that 1d models generalize poorly with a small dataset.
    • Inspired by this paper, I experimented with adding a large amount of noise to prolong the number of steps where the model is in a critical or near-critical state during sampling. This significantly increased the quality of the samples while drastically lowering the resulting sample temperature, which is desirable for music.
  • In August/2024 I began a near-complete rewrite of the codebase to pay down the technical debt that had accumulated over the last year of rapid experimentation.

    • I re-implemented the last of the remaining code that used diffusers and removed the dependency.
    • I built a modular system to allow for easy experimentation with the training process for arbitrary modules.
    • The training code was re-written and uses configuration files as much as possible to allow for automated hyperparameter or model architecture search.
    • Dataset pre-processing code was completely re-written. I intend to spend more time developing tools / models to filter and augment the dataset, which will allow me to use a larger volume of lower-quality data.
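The separable attention from the August/2023 notes (alternately merging rows and columns into the batch dimension) can be illustrated with a small, framework-free sketch. It assumes single-head attention with no learned q/k/v projections and applies one row pass followed by one column pass; a real model would reshape tensors so that rows (then columns) are folded into the batch dimension and would use learned projections and multiple heads. None of these function names come from this repository.

```python
import math

Vector = list[float]
Grid = list[list[Vector]]  # indexed as [height][width][channels]


def attend(seq: list[Vector]) -> list[Vector]:
    """Single-head softmax self-attention over one sequence (q = k = v = input, no projections)."""
    d = len(seq[0])
    out = []
    for q in seq:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in seq]
        m = max(scores)  # subtract the max for a numerically stable softmax
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, seq)) / z for c in range(d)])
    return out


def row_attention(grid: Grid) -> Grid:
    # Each row is an independent sequence, as if the rows were folded into the batch.
    return [attend(row) for row in grid]


def column_attention(grid: Grid) -> Grid:
    # Transpose so each column becomes a sequence, attend, then transpose back.
    cols = [list(col) for col in zip(*grid)]
    attended = [attend(col) for col in cols]
    return [list(row) for row in zip(*attended)]


def separable_attention(grid: Grid) -> Grid:
    # One row pass then one column pass: every position can reach every other in two hops.
    return column_attention(row_attention(grid))
```

For an H×W grid, full attention scores (H·W)² pairs, while one row pass plus one column pass scores H·W·(H+W) pairs, which is what makes attention affordable on a high-resolution 2d format.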
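The v-prediction objective with a cosine-based schedule from the August/2023 notes can be sketched without any framework. This is a minimal illustration assuming the common variance-preserving parameterization alpha = cos(t·π/2), sigma = sin(t·π/2); it is not the code this repository uses, and the function names are hypothetical.

```python
import math
import random


def cosine_alpha_sigma(t):
    """Variance-preserving cosine schedule: alpha = cos(t*pi/2), sigma = sin(t*pi/2), t in [0, 1]."""
    return math.cos(t * math.pi / 2), math.sin(t * math.pi / 2)


def add_noise(x0, eps, t):
    """Noisy input x_t = alpha * x0 + sigma * eps."""
    a, s = cosine_alpha_sigma(t)
    return [a * x + s * e for x, e in zip(x0, eps)]


def v_target(x0, eps, t):
    """Training target for v-prediction: v = alpha * eps - sigma * x0."""
    a, s = cosine_alpha_sigma(t)
    return [a * e - s * x for x, e in zip(x0, eps)]


def x0_from_v(x_t, v, t):
    """Clean-signal estimate from a (predicted) v: x0 = alpha * x_t - sigma * v."""
    a, s = cosine_alpha_sigma(t)
    return [a * xt - s * vt for xt, vt in zip(x_t, v)]


def eps_from_v(x_t, v, t):
    """Noise estimate from a (predicted) v: eps = sigma * x_t + alpha * v."""
    a, s = cosine_alpha_sigma(t)
    return [s * xt + a * vt for xt, vt in zip(x_t, v)]


rng = random.Random(0)
x0 = [rng.uniform(-1, 1) for _ in range(8)]  # stand-in for a short clean audio snippet
eps = [rng.gauss(0, 1) for _ in range(8)]
t = 0.3
x_t = add_noise(x0, eps, t)
v = v_target(x0, eps, t)
```

Because alpha² + sigma² = 1, a perfect v prediction recovers both x0 and eps exactly. One commonly cited advantage over eps-prediction is that recovering x0 never divides by alpha, which approaches 0 at the highest noise levels.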
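Why a multi-scale spectral loss can beat a point-wise loss (December/2023 notes) can be shown with a toy, pure-Python comparison. This sketch compares only STFT magnitudes at a few hypothetical FFT sizes, using a periodic Hann window and a naive DFT; it omits the phase term and power-density formulation described above, so it illustrates the general technique rather than the loss used in this repository.

```python
import cmath
import math


def stft_magnitudes(x, n_fft, hop):
    """Magnitude STFT with a periodic Hann window and a naive O(n^2) DFT (fine for toy sizes)."""
    window = [0.5 - 0.5 * math.cos(2 * math.pi * i / n_fft) for i in range(n_fft)]
    frames = []
    for start in range(0, len(x) - n_fft + 1, hop):
        seg = [x[start + i] * window[i] for i in range(n_fft)]
        frames.append([
            abs(sum(seg[n] * cmath.exp(-2j * math.pi * k * n / n_fft) for n in range(n_fft)))
            for k in range(n_fft // 2 + 1)
        ])
    return frames


def multiscale_spectral_loss(x, y, fft_sizes=(16, 32, 64)):
    """Mean absolute difference of STFT magnitudes, averaged over several resolutions."""
    total = 0.0
    for n_fft in fft_sizes:
        a = stft_magnitudes(x, n_fft, n_fft // 4)
        b = stft_magnitudes(y, n_fft, n_fft // 4)
        diffs = [abs(p - q) for fa, fb in zip(a, b) for p, q in zip(fa, fb)]
        total += sum(diffs) / len(diffs)
    return total / len(fft_sizes)


def pointwise_l1(x, y):
    """Sample-by-sample mean absolute error."""
    return sum(abs(p - q) for p, q in zip(x, y)) / len(x)


# A sine with a period of 8 samples, and the same sine shifted by a quarter period.
x = [math.sin(2 * math.pi * n / 8) for n in range(256)]
y = [math.sin(2 * math.pi * n / 8 + math.pi / 2) for n in range(256)]
```

The two signals sound identical, yet the point-wise L1 between them is about 0.85 while the spectral magnitude loss is essentially zero: a point-wise loss spends its capacity on phase alignment that is barely perceptible, which is one reason it tracks perceived reconstruction quality poorly.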
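The class-label conditioning and classifier-free guidance from the April/2024 notes, including choosing a weighted combination of games, can be sketched as follows. This assumes a linear class embedding, so mixing games is simply a weighted sum of embedding rows, and uses the standard guidance formula uncond + scale · (cond − uncond); the game names, embedding table, and function names are hypothetical, not taken from this repository.

```python
def mix_class_labels(weights, classes):
    """Normalized weight vector over classes, e.g. {"game_a": 3, "game_b": 1} -> [0.75, 0.25, 0.0]."""
    total = sum(weights.values())
    return [weights.get(c, 0.0) / total for c in classes]


def class_embedding(label_weights, table):
    """Linear class embedding: weighted sum of one embedding row per class."""
    dim = len(table[0])
    return [sum(w * row[d] for w, row in zip(label_weights, table)) for d in range(dim)]


def classifier_free_guidance(pred_cond, pred_uncond, scale):
    """Extrapolate from the unconditional prediction toward (and past) the conditional one."""
    return [u + scale * (c - u) for c, u in zip(pred_cond, pred_uncond)]


classes = ["game_a", "game_b", "game_c"]      # hypothetical class names
table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical 2-d class embeddings
label = mix_class_labels({"game_a": 3, "game_b": 1}, classes)
emb = class_embedding(label, table)
guided = classifier_free_guidance([1.0, 2.0], [0.5, 1.0], scale=1.5)
```

The unconditional branch is typically learned by randomly dropping the class label during training, and scale = 1 reduces to the purely conditional prediction; larger scales trade diversity for stronger adherence to the chosen game(s).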
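The EDM noise schedule and deterministic sampler adopted in May/2024 can be sketched as below. The constants are the defaults from the EDM paper (Karras et al., 2022), not necessarily this repository's settings, and the denoiser is a hypothetical toy: the analytically optimal denoiser for zero-mean Gaussian data with standard deviation sigma_data. The update is the first-order Euler step of the probability-flow ODE, which the EDM paper relates to DDIM.

```python
def edm_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, spaced uniformly in sigma ** (1 / rho)."""
    start, end = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return [(start + i / (n - 1) * (end - start)) ** rho for i in range(n)]


def toy_denoiser(x, sigma, sigma_data=0.5):
    """Optimal denoiser for zero-mean Gaussian data: shrink toward 0 by sd^2 / (sd^2 + sigma^2)."""
    scale = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2)
    return [xi * scale for xi in x]


def euler_sample(denoise, x, sigmas):
    """Deterministic Euler integration of dx/dsigma = (x - D(x, sigma)) / sigma."""
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d = [(xi - di) / s_cur for xi, di in zip(x, denoise(x, s_cur))]
        x = [xi + (s_next - s_cur) * di for xi, di in zip(x, d)]
    return x


sigmas = edm_sigmas(32)
x_init = [sigmas[0] * z for z in (1.0, -0.5, 2.0)]  # unit "noise" scaled to sigma_max
sample = euler_sample(toy_denoiser, x_init, sigmas)
```

The rho = 7 spacing concentrates steps at low noise levels, where most perceptually relevant detail is resolved.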
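The stratified sigma sampling from the May/2024 notes can be sketched with the standard library alone. This sketch assumes the plain EDM log-normal training distribution (P_mean = -1.2, P_std = 1.2, the EDM paper's defaults) rather than the error-variance-based pdf described above; given a tabulated CDF, the same inverse-CDF trick applies to any such pdf. The function name is hypothetical.

```python
import math
import random
from statistics import NormalDist


def stratified_lognormal_sigmas(batch_size, p_mean=-1.2, p_std=1.2, rng=None):
    """Draw one training sigma from each equal-probability stratum of a log-normal distribution."""
    if rng is None:
        rng = random.Random()
    dist = NormalDist(mu=p_mean, sigma=p_std)
    sigmas = []
    for i in range(batch_size):
        u = (i + rng.random()) / batch_size   # uniform within stratum [i/B, (i+1)/B)
        u = min(max(u, 1e-12), 1.0 - 1e-12)   # inv_cdf requires 0 < u < 1
        sigmas.append(math.exp(dist.inv_cdf(u)))  # ln(sigma) ~ N(p_mean, p_std^2)
    rng.shuffle(sigmas)  # avoid correlating sigma with position in the batch
    return sigmas


sigmas = stratified_lognormal_sigmas(8, rng=random.Random(0))
```

With independent draws, a small mini-batch can easily cluster around similar noise levels; stratification guarantees each batch spans the whole distribution while every individual sigma keeps the same marginal distribution.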
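Training with EMA weights of multiple lengths (May/2024 notes) can be sketched as below. This uses a classic fixed-decay EMA parameterized by a half-life in optimizer steps, a simpler stand-in for the power-function EMA profiles used by edm2; the class name and half-life values are hypothetical.

```python
class MultiLengthEMA:
    """Maintains one exponential moving average of the parameters per half-life (in steps)."""

    def __init__(self, params, half_lives):
        # beta = 0.5 ** (1 / h): after h updates, the weight on the initial value has halved.
        self.betas = {h: 0.5 ** (1.0 / h) for h in half_lives}
        self.averages = {h: list(params) for h in half_lives}

    def update(self, params):
        """Blend the current parameters into every tracked average."""
        for h, avg in self.averages.items():
            beta = self.betas[h]
            for i, p in enumerate(params):
                avg[i] = beta * avg[i] + (1.0 - beta) * p


ema = MultiLengthEMA([0.0], half_lives=[10, 100])
for _ in range(10):
    ema.update([1.0])
```

Short half-lives follow recent weights closely while long ones average over more of training; keeping several makes it possible to compare them by listening after training ends.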

Some additional notes:

  • The training code supports multiple GPUs and distributed training through Hugging Face Accelerate; logging currently goes to TensorBoard.
  • The dataset pre/post-processing code is included in this repository, along with everything needed to train a new model on your own data.
  • All the code in this repository is tested on both Windows and Linux, although performance is significantly better on Linux.