lucidrains/magvit2-pytorch

Training difficulties

LouisSerrano opened this issue · 51 comments

Hi, I am experiencing some difficulties during the training of magvit2. I don't know whether I made a mistake somewhere or where the problem might be coming from.

It seems that my understanding of the paper might be erroneous. I tried with 2 codebooks of size 512 and I can't seem to fit the training data; the training is really unstable. I tried replacing the LFQ with a classical VQ, and it was more stable and able to converge.
What config did you use for training the model?

@LouisSerrano hey Louis and thank you for reporting this

that is disappointing to hear, as I had high hopes for LFQ. there is the possibility that I implemented it incorrectly, but the authors had already given it a code review, so that seems less likely. i'll be training the model tonight (regaining access to my deep learning rig) so i can see this instability for myself.

@LouisSerrano as a silver lining, i guess your experiments show the rest of the framework to be working ok. i can run a few toy experiments on LFQ this morning before running it on a big image dataset tonight and see what the issue is

@LouisSerrano just to make sure, are you on the latest version of the vector-quantize-pytorch library?

Ok thank you very much, it might simply be an error on my side, probably in the configuration. I used the model config that you suggested in the Readme.md, and for LFQ I used 1 codebook of size 512.

@lucidrains yes I am using 1.10.4

@LouisSerrano no that should be fine, you could try increasing to 2048 or 4096, but shouldn't make a big difference

@lucidrains To clarify, I tried with a different dataset, a smaller one actually, which should be less challenging than the ones from the paper.

@LouisSerrano this is kind of a stab in the dark, but would you like to try adding lfq_activation = nn.Tanh() (with from torch import nn) to your VideoTokenizer init? (v0.0.64)

they didn't have this in the paper, but i think it should make sense

@LouisSerrano let me do some toy experiments right now and make sure it isn't some obvious bug

@lucidrains Ok I am going to try to increase the codebook size, just in case. Sure, I can check with tanh activation.

@LouisSerrano thank you! 🙏

@lucidrains Thanks for the awesome work !

@LouisSerrano yea no problem! well, we'll see if this work pans out. thank you for attempting to replicate!

for the toy task, LFQ looks ok compared to VQ. Tanh won't work, but you can try Tanh x 10

lfq_activation = lambda x: torch.nn.functional.tanh(x) * 10,
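
For reference, a tiny runnable sketch of what this activation does compared to a plain Tanh; the interpretation in the comments is an assumption about why the unscaled version struggles, not something from the paper.

```python
import torch

# the suggested pre-quantization activation: bounded like tanh, but rescaled so the
# values fed to LFQ span roughly (-10, 10) instead of saturating at (-1, 1)
lfq_activation = lambda x: torch.nn.functional.tanh(x) * 10

x = torch.linspace(-5, 5, steps=5)
print(torch.tanh(x))        # squashed into (-1, 1)
print(lfq_activation(x))    # same shape, rescaled toward (-10, 10)
```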

Ok thanks, I will try this! I'll let you know if I encounter any issues. Also, what weights do you use for the commitment and entropy losses?

@LouisSerrano i think for the commitment loss it is ok to keep it at the same value as regular VQ, 1., but i'm not sure about the respective per-sample and batch entropy weights

@LouisSerrano i just increased the batch entropy weight a bit, to what works for the toy task (just fashion mnist)
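
For anyone following along, a minimal sketch of where the entropy-related weights live, assuming the standalone LFQ from vector-quantize-pytorch (kwargs as shown in that repo's README; the exact values the tokenizer ships with may differ, and the commitment weight is a separate knob):

```python
import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    dim = 16,
    codebook_size = 2 ** 16,     # 65536 implicit codes, one bit per dimension
    entropy_loss_weight = 0.1,   # weight on the entropy auxiliary loss
    diversity_gamma = 1.         # weight of the batch (codebook) entropy within it
)

image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(image_feats)
```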

@lucidrains Ok great, thanks for the tips.

@LouisSerrano hey Louis, just noticed that you used the default layer structure from the readme. feel free to try the updated one and see how that fares with LFQ

would still be interested to know if tanh(x) * 10 helps resolve your previous instability, if you have an experiment in progress.

This was with my previous config. I benchmarked against fsq. I also show the aux_loss, which is going crazy for lfq_tanh

[screenshot: training curves for the previous config, comparing lfq_tanh against fsq, including the aux_loss]

@LouisSerrano thank you! do you have the original unstable LFQ plot too? and wow, you actually gave FSQ a test drive; what level settings did you use for it?

@LouisSerrano did you compare it to baseline VQ by any chance?

that plot for FSQ looks very encouraging, maybe i'll add it to the repository tonight for further research. thanks again for running the experiment and sharing it

I think my main concern was about the auxiliary loss. I did not look too much into the details, but I assumed that somehow the model struggled to maintain a diverse codebook and good reconstructions.
[screenshot: auxiliary loss curves]

@lucidrains I have not used VQ yet, but I am going to launch it tonight, so I'll let you know when I get the results. Also, I tested two recommended configurations for fsq: levels_big = [8,8,8,6,5] and levels_mid = [8,5,5,5]. On my simple problem I expected similar results for both, and that is indeed the case, so I am happy with it!
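
For anyone reproducing this, a minimal sketch of trying those two level settings with the standalone FSQ from vector-quantize-pytorch (tensor shapes here are just illustrative):

```python
import torch
from vector_quantize_pytorch import FSQ

# the two recommended level settings compared above
levels_big = [8, 8, 8, 6, 5]   # 8*8*8*6*5 = 15360 implicit codes
levels_mid = [8, 5, 5, 5]      # 8*5*5*5  = 1000 implicit codes

fsq = FSQ(levels_mid)                       # feature dim defaults to len(levels)
x = torch.randn(1, 1024, len(levels_mid))   # (batch, sequence, dim)
xhat, indices = fsq(x)                      # indices: (1, 1024) integer code ids
```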

@lucidrains I will give your config a shot, thanks again!

@LouisSerrano ok, let me know how VQ goes on your end! i'm hoping it is worse than both LFQ and FSQ, on that broken layer architecture you used (which is still a good benchmark)

if VQ outperforms both LFQ and FSQ on both the old and new layer architecture, then i don't know what to do. that would be the worst outcome

@LouisSerrano i'll be running experiments tonight and making sure image pretraining is seamless. will update on what i see

@lucidrains Don't worry, I will continue to inspect the code on my side. I made some adjustments to adapt the code to my dataset, so there could still be an issue somewhere in my code alone. We are going to inspect the reconstructions on the same dataset with a different VQ implementation to see if I possibly made an error somewhere.

@LouisSerrano looking at your results, i'll put some more work into FSQ today and bring it up to parity with LFQ. the magvit authors claim LFQ helped them break SOTA while FSQ did not demonstrate that, but it is still early days.

@LouisSerrano added FSQ, hope it helps!

@lucidrains Thanks ! Yes I agree, it would be interesting to have the FSQ vs LFQ comparison for magvit2.

@LouisSerrano hey Louis, i had a big bug where i forgot to include the quantizer parameters among the ones being optimized

may obsolete all your results, sorry!

@lucidrains Thanks for this, that's good to hear! On my data, which might be a bit different from the data in the paper, I get fsq > vq > lfq so far. But this does not mean that lfq does not work.
[screenshot: training curves comparing fsq, vq, and lfq]

@LouisSerrano deleted my comment, because i realized i didn't make the number of codes equal lol. i'm not much of an experimentalist, as you can probably tell

is this with the updated repository with the projections into the codebook dimension being optimized? what is the codebook size you are using for all of them?

@LouisSerrano also, are the sampled reconstructions looking ok for your dataset?

@lucidrains I am using around 512 codes for each: [8, 6, 4, 3] for fsq, (512, 16) for vq and 2**9 for lfq.
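
A quick sanity check of those effective codebook sizes, assuming the FSQ code count is the product of its levels:

```python
fsq_codes = 8 * 6 * 4 * 3   # 576 implicit codes for levels [8, 6, 4, 3]
vq_codes  = 512             # codebook_size 512, codebook_dim 16
lfq_codes = 2 ** 9          # 512 codes from 9 binary dimensions

print(fsq_codes, vq_codes, lfq_codes)   # 576 512 512
```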

thank you 🙏

Interesting! I moved over to stable-audio-tools and implemented FSQ there, and it has been working quite well as a VQ replacement for DAC. Thinking of trying out some ablations (including LFQ) once I do proper training runs.

@sekstini awesome!

@sekstini Looking forward to seeing your results!

[image: stacked FSQ indices]

Sneak peek of stacked FSQ indices corresponding to 1 second of 48kHz audio ^^

@sekstini got residual FSQ done this morning

you are both getting me excited

giving residual fsq a test drive with soundstream. will let this run throughout the day and check on it from my phone
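
For anyone who wants to run the same experiment, a minimal sketch using ResidualFSQ from vector-quantize-pytorch; the parameter values below are illustrative, not the soundstream config used here:

```python
import torch
from vector_quantize_pytorch import ResidualFSQ

residual_fsq = ResidualFSQ(
    dim = 256,               # input feature dimension
    levels = [8, 5, 5, 3],   # FSQ levels shared by each quantizer
    num_quantizers = 8       # number of residual quantizers, analogous to RVQ
)

x = torch.randn(1, 1024, 256)
quantized, indices = residual_fsq(x)   # indices: (1, 1024, 8)
```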

Doing some test runs comparing SFSQ (Stacked FSQ) with RVQ here:
https://wandb.ai/sekstini/smolgoose

Will maybe throw in LFQ later today.

hmm, not seeing much of a difference, at least for soundstream. i'll keep trying though

> for the toy task, LFQ looks ok compared to VQ. Tanh won't work, but you can try Tanh x 10
>
> lfq_activation = lambda x: torch.nn.functional.tanh(x) * 10,

Hi @lucidrains, I couldn't find any instances where lfq_activation is actually used. I wonder how this activation function is supposed to work; let me know if there's something I might be overlooking.

@lucidrains @LouisSerrano when using FSQ in place of LFQ, are you producing a latent image/volume of size (len(levels), H, W)? And if so, to get a token sequence for FSQ, are you just converting from codes to indices and flattening (i.e. (len(levels), H, W) -> (1, H, W) -> (1HW))?
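
Not speaking for the authors, but here is one way that flattening could look, assuming a channel-first latent and FSQ's forward pass returning one integer index per spatial position:

```python
import torch
from vector_quantize_pytorch import FSQ

levels = [8, 5, 5, 5]
fsq = FSQ(levels)

# hypothetical encoder output: one code vector of size len(levels) per spatial position
H, W = 16, 16
latent = torch.randn(1, len(levels), H, W)                     # (B, len(levels), H, W)

# move the level dimension last and flatten space: (B, H*W, len(levels))
x = latent.permute(0, 2, 3, 1).reshape(1, H * W, len(levels))

quantized, indices = fsq(x)   # indices: (B, H*W), one token id per position
tokens = indices.flatten()    # (H*W,) token sequence
```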

> @sekstini got residual FSQ done this morning
>
> you are both getting me excited

me too :D