lucidrains/BS-RoFormer

Training of BS-RoFormer

ZFTurbo opened this issue · 36 comments

I tried to train this neural net without any success. SDR is stuck at around 2.1 for vocals and never improves. If somebody has better results, please let me know.

@ZFTurbo did you follow the band splitting hyperparameters as in the paper?

Unfortunately, no. I don't really understand how the authors do the band split. Also, it must be done inside the NN code, which I left unchanged.

In the standard band split method it is very simple. A large plane is split by frequency into several planes of the same size, and you get something like several image channels: (4096, 512) -> (8, 512, 512). Each channel represents some range of frequencies.
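The even split described above can be sketched as a single reshape. This is only an illustration with a toy array, assuming the frequency axis comes first; the sizes are the ones from the post and NumPy is used purely for convenience.

```python
import numpy as np

# Toy spectrogram: 4096 frequency bins by 512 time frames.
spec = np.arange(4096 * 512, dtype=np.float32).reshape(4096, 512)

# Even band split: cut the frequency axis into 8 equal bands of 512 bins each.
n_bands = 8
bands = spec.reshape(n_bands, 4096 // n_bands, 512)

print(bands.shape)  # (8, 512, 512)

# Band k holds the frequency bins [512 * k, 512 * (k + 1)).
assert np.array_equal(bands[3], spec[3 * 512:4 * 512])
```

Each of the 8 "channels" then covers one contiguous frequency range, which is the uniform case the paper's uneven scheme departs from.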

From the paper:

We use the following band-split scheme: 2 bins per band for frequencies under 1000 Hz, 4 bins per band between 1000 Hz and 2000 Hz, 12 bins per band between 2000 Hz and 4000 Hz, 24 bins per band between 4000 Hz and 8000 Hz, 48 bins per band between 8000 Hz and 16000 Hz, and the remaining bins beyond 16000 Hz equally divided into two bands. This results in a total of 62 bands. All bands are non-overlapping.

I don't understand it in terms of tensors.

I think they use the same method as the BSRNN neural net does for BandSplit:

https://github.com/sungwon23/BSRNN/blob/main/module.py#L73

yea, I'll get around to setting the proper band frequency hyperparameters and you can give it another go

I'm growing interested in this technique, as I think it can be applied for medical segmentation

yes you would need to set it according to this. I was also planning on adding the overlapping frequency bands they mentioned in the section of stuff they wanted to try next

yes, it is all built.

you need to have the correct tuple of 62 integers here https://github.com/lucidrains/BS-RoFormer/blob/main/bs_roformer/bs_roformer.py#L222

So I need to make a tuple that consists of 62 integers? What do these integers mean? Are they ranges in Hz?

For example, the first 6, if following the paper: (500, 1000, 1250, 1500, 1750, 2000, ...)?

UPD: I think it's not in Hz. It's 62 numbers from 0 up to (FFT size / 2).
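Judging by the configurations posted later in this thread, whose entries add up to 1025, the integers look like the number of STFT bins in each band rather than Hz values or bin indices. Under that reading, a minimal sketch of how widths map to bin ranges and approximate frequencies could look like this. The helper names `band_edges` and `bin_to_hz` are hypothetical, and the 44100 Hz sample rate and n_fft of 2048 are assumptions taken from a later comment.

```python
from itertools import accumulate


def band_edges(freqs_per_bands):
    """Turn per-band bin counts into (start, stop) bin index ranges."""
    stops = list(accumulate(freqs_per_bands))
    starts = [0] + stops[:-1]
    return list(zip(starts, stops))


def bin_to_hz(bin_index, sample_rate=44100, n_fft=2048):
    """Approximate center frequency of an STFT bin (assumed settings)."""
    return bin_index * sample_rate / n_fft


widths = (2, 2, 4, 4)  # hypothetical toy widths, not the paper's scheme
edges = band_edges(widths)
print(edges)  # [(0, 2), (2, 4), (4, 8), (8, 12)]

for start, stop in edges:
    print(f"bins {start}..{stop - 1}: from about {bin_to_hz(start):.0f} Hz")
```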

yes, I believe it is exactly ranges of freqs in order of low to high frequencies. however I have not sat down and worked out that section of the paper yet

Also: 2 + 4 + 12 + 24 + 48 + N > 62. Maybe I don't understand what 62 is.

haha, I honestly haven't sat down and worked it out myself yet. but it is ok since I think the big idea is just uneven splits across frequencies and project to tokens with own MLP

you should just try a naive even split of frequencies by 64 and run it again, before going for the precise breakdown

at the moment, you are doing attention of 2 tokens across the frequency axis, which does basically nothing

I'm trying this now:

from bs_roformer import BSRoformer

model = BSRoformer(
    stereo=True,
    dim=256,
    depth=12,
    time_transformer_depth=2,
    freq_transformer_depth=2,
    # 32 bands: 31 bands of 32 bins plus one of 33 bins (1025 bins in total)
    freqs_per_bands=(
        32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 33,
    )
)
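A tuple like the one above does not have to be typed by hand. Here is a small sketch of a helper that produces it, assuming 1025 frequency bins (which is what the widths above add up to, and what an n_fft of 2048 would give). `even_band_split` is a hypothetical name, not part of the library.

```python
def even_band_split(n_bins, n_bands):
    """Split n_bins into n_bands widths that differ by at most one bin."""
    base, extra = divmod(n_bins, n_bands)
    # The last `extra` bands absorb the leftover bins.
    return tuple(
        base + (1 if i >= n_bands - extra else 0) for i in range(n_bands)
    )


n_fft = 2048               # assumed STFT size
n_bins = n_fft // 2 + 1    # 1025 one-sided frequency bins
freqs_per_bands = even_band_split(n_bins, 32)

print(len(freqs_per_bands), sum(freqs_per_bands))  # 32 1025
```

The widths have to cover every bin exactly once, so checking that the sum equals the bin count before building the model is a cheap sanity test.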

nice! yeah, that should work better

Hi @ZFTurbo, do you have the results yet?

According to what is written in the paper,

"According to our observations, the training progress of BS-Transformer is very slow, and it still remains low SDRs after two weeks of training on Musdb18HQ. Instead, BS-RoFormer models with L=6 get converged within a week."

and

"Models with L=6 are trained solely on the Musdb18HQ training set using 16 Nvidia V100-32GB GPUs."

This model likely requires many GPUs to train for several weeks.

I'm currently training:

model = BSRoformer(
    stereo=True,
    dim=128,
    depth=12,
    time_transformer_depth=1,
    freq_transformer_depth=1,
    freqs_per_bands=(
        32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 33,
    )
)

It's training very, very slowly. I use batch size 12. One epoch of 1000 batches takes 1 hour 40 minutes to finish. But the loss is constantly decreasing now. From the logs:

Training loss: 3.939826 SDR vocals: 1.7817
Training loss: 2.848435 SDR vocals: 1.4592
Training loss: 2.587866 SDR vocals: 1.5311
Training loss: 2.503295 SDR vocals: 1.3758
Training loss: 2.497599 SDR vocals: 1.9409
Training loss: 2.438752 SDR vocals: 1.8686
Training loss: 2.427704 SDR vocals: 1.4726
Training loss: 2.412679 SDR vocals: 1.4579
Training loss: 2.384547 SDR vocals: 1.7167
Training loss: 2.401358 SDR vocals: 2.0094
Training loss: 2.358802 SDR vocals: 1.3975
Training loss: 2.351327 SDR vocals: 1.6052
Training loss: 2.350673 SDR vocals: 1.9241
Training loss: 2.311954 SDR vocals: 2.3140

@ZFTurbo hey, i apologize but the author reached out last night, and there was a bug in how i was folding the dimensions

should be fixed in the latest version! (now it is actually axial attention 😓)

should train faster too

I restarted. It became much faster. 15 minutes now for the same epoch.

nice yea, that's axial attention at work

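In my opinion, the raw audio is 44100 Hz and n_fft is 2048, so we get 1025 bins, and every bin is about 44100/2/1025 ~= 21.5 Hz.

| freq range (Hz) | bins in range | bins per band | bands |
| --- | --- | --- | --- |
| f < 1000 | 46.5 | 2 | 24 |
| 1000 < f < 2000 | 46.5 | 4 | 12 |
| 2000 < f < 4000 | 93.0 | 12 | 8 |
| 4000 < f < 8000 | 186.0 | 24 | 8 |
| 8000 < f < 16000 | 372.1 | 48 | 8 |
| 16000 < f < 22050 | 281.4 | remaining bins split into two bands | 2 |

so the total number of bands is 24 + 12 + 8 + 8 + 8 + 2 = 62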
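The arithmetic above can be reproduced in a few lines. This is a sketch under the same assumptions as the post (44100 Hz audio, n_fft of 2048, about 21.5 Hz per bin). Rounding the band count up in each range is my assumption; it happens to reproduce the counts given above.

```python
import math

sample_rate = 44100
n_fft = 2048
n_bins = n_fft // 2 + 1                 # 1025
hz_per_bin = sample_rate / 2 / n_bins   # about 21.5 Hz, as in the post

# (low Hz, high Hz, bins per band) as quoted from the paper
scheme = [
    (0, 1000, 2),
    (1000, 2000, 4),
    (2000, 4000, 12),
    (4000, 8000, 24),
    (8000, 16000, 48),
]

band_counts = []
for low, high, bins_per_band in scheme:
    bins_in_range = (high - low) / hz_per_bin
    # Assumption: round up so the whole range is covered.
    band_counts.append(math.ceil(bins_in_range / bins_per_band))

band_counts.append(2)  # everything above 16000 Hz is split into two bands

print(band_counts, sum(band_counts))  # [24, 12, 8, 8, 8, 2] 62
```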

Yes. I already fixed it:

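band_split_params = (
    # 24 bands of 2 bins (below 1000 Hz)
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2,
    # 12 bands of 4 bins (1000 to 2000 Hz)
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    # 8 bands each of 12, 24 and 48 bins (2000 to 16000 Hz)
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    # remaining bins above 16000 Hz, split into two bands
    128, 129,
)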
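As a quick check, the same tuple can be written compactly and tested against the bin count. This assumes an n_fft of 2048, so 1025 one-sided bins, as discussed above.

```python
band_split_params = (
    (2,) * 24      # below 1000 Hz
    + (4,) * 12    # 1000 to 2000 Hz
    + (12,) * 8    # 2000 to 4000 Hz
    + (24,) * 8    # 4000 to 8000 Hz
    + (48,) * 8    # 8000 to 16000 Hz
    + (128, 129)   # remaining bins above 16000 Hz, in two bands
)

n_fft = 2048  # assumed STFT size
assert len(band_split_params) == 62
assert sum(band_split_params) == n_fft // 2 + 1  # 1025 bins, all covered once

print(len(band_split_params), sum(band_split_params))  # 62 1025
```

The first five groups use 768 bins, which leaves 257 for the last two bands, hence 128 and 129.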

think this should be solved

Hi @lucidrains + @ZFTurbo

I work with birds and have an audio dataset of the sounds they make that I would like to train on with this code. I have had success with other models that were written for music. Could I request that a way to train this model be added to the code? Sorry, I cannot accomplish this myself.

@ZFTurbo hey, have you ever played around with complex neural networks? do you think it is worth doing a complex version of BS-Roformer. or a waste of time?

No, I never tried.

@ZFTurbo ah ok, thought you may have tried it before, as it seems you are in the business of winning kaggle competitions

@ZFTurbo ok, maybe i'll think for a bit more until deciding whether to do a complex version of BS-Roformer

@lucidrains They recently published a derived work (https://arxiv.org/abs/2310.01809) that also uses RoFormer but, instead of BandSplit, uses Mel matrixing. It would be really great to see it reproduced too!

@jarredou oh wow, yes indeed

in the spirit of open source, PRs are always welcome, but feel free to open an issue and i can get back to this if no one else does

I'll share that I only open sourced this work because a precocious high schooler reached out asking me to consider it lol

@jarredou had some time this evening and knocked it out here

think this may be my last open sourcing in the music separation space for a while

@lucidrains You're amazing! ❤️

@jarredou thanks for the sponsor! 🙏 means a lot