Alescontrela/viper_rl

Small Issues and Fixes

xhlsgit opened this issue · 3 comments

Hello Ale. Thanks for your great work, it's really interesting! While trying to reproduce some results I ran into a few problems. I fixed them and hope the details below are helpful to you!

  • For requirements.txt:
    • Flax has released a newer version, but the latest release has an unresolved bug that causes an error. Details:
      • Error: installing the latest version, flax==0.7.0, leads to this error at runtime:
        • unexpected keyword argument 'restore_with_serialized_types'
      • Reason: flax 0.7.0 itself is buggy.
        • Link: "Set default types in Flax for Orbax restoration and add restore_with_serialized_types in preparation for an upcoming change." by @copybara-service in #3165
      • Fix: pin flax==0.6.11 in requirements.txt (a quick version check is sketched after this list).
  • One codebase bug:
    • Problem:
      • VideoGPT.__call__ does not accept "text=text, text_mask=text_mask", so the call that passes them fails.
    • Fix: remove "text=text, text_mask=text_mask" from the call at line 47 (a small illustration is sketched after this list).
  • Load model issue:
    • Problem:
      • After downloading the checkpoints, I tried to run this command:
        • python scripts/train_vqgan.py -o viper_rl_data/checkpoints/dmc_vqgan -c viper_rl/configs/vqgan/dmc.yaml
      • but got this error:
        • The target dict keys and state dict keys do not match, target dict contains keys {'ResnetBlock_9', 'Upsample_2', 'ResnetBlock_8'} which are not present in state dict at path ./vqgan_params/decoder
    • Fix:
      • In viper_rl/configs/vqgan/dmc.yaml, change:
        • ch_mult: [1, 2, 2, 2] -> ch_mult: [1, 2, 2]
        • patch_size: [8, 8] -> patch_size: [4, 4]
      • The shorter ch_mult presumably removes the extra decoder stage (the Upsample_2 / ResnetBlock_8 / ResnetBlock_9 modules that the released checkpoint does not contain); a sketch of applying the change is included after this list.
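
For the requirements fix, here is a tiny sanity check you can run in the environment you use for viper_rl to confirm the downgrade actually took effect (the version string is just the one pinned above):

```python
# Sanity check for the flax pin suggested above.
import flax

print("flax version:", flax.__version__)
assert flax.__version__ == "0.6.11", (
    f"Expected flax 0.6.11 (see requirements fix above), got {flax.__version__}"
)
```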
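For the VideoGPT call-site bug, a self-contained illustration of why the extra keyword arguments fail. The class below is a hypothetical stand-in, not the real VideoGPT code; it only mirrors the signature mismatch described above:

```python
# Hypothetical stand-in: __call__ takes no text inputs, like the real module.
class VideoGPT:
    def __call__(self, embeddings, label=None):
        return embeddings  # placeholder for the actual forward pass


model = VideoGPT()
batch = [0.0, 1.0, 2.0]

model(batch, label=3)  # fixed call: only arguments __call__ actually accepts

try:
    # the buggy call site also passed text/text_mask, which __call__ rejects
    model(batch, label=3, text=None, text_mask=None)
except TypeError as err:
    print(err)  # e.g. "__call__() got an unexpected keyword argument 'text'"
```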
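And for the dmc.yaml fix, one convenient way to apply the override programmatically. It assumes PyYAML is installed and that ch_mult and patch_size are top-level keys, as the fix above suggests; editing the file by hand works just as well (and preserves comments, which this snippet does not):

```python
# Sketch: patch viper_rl/configs/vqgan/dmc.yaml so the decoder matches the
# released DMC checkpoint (one fewer resolution stage, smaller patches).
import yaml

cfg_path = "viper_rl/configs/vqgan/dmc.yaml"

with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg["ch_mult"] = [1, 2, 2]    # was [1, 2, 2, 2]
cfg["patch_size"] = [4, 4]    # was [8, 8]

with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)  # note: rewrites the file, drops comments
```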

Thanks @xhlsgit for these fixes! I'll make a PR with these changes when I get the chance.

A similar problem occurred with the Downsample blocks; I posted about it in the still-open issue. I hope somebody can take a look.

Now, for some reason, the g loss, ae loss, and vq loss start skyrocketing at around 150k timesteps and never come back down; the vq loss even went up to 10^13. I assume that is not intended. Could this be related to the changes proposed here, or is something else going on?