google-research/scenic

AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'

Aki1991 opened this issue · 9 comments

Hi all, I am trying to fine-tune our model using owl_vit.

But when I try to run it, I get this error: AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'. The JAX version I am using is 0.4.30. With jax 0.4.23 it works, but then training does not use the GPU, which slows it down a lot. Is there a way to use jax 0.4.30 and still fix this error?
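Before trading one JAX version for another, it can help to confirm which backend JAX actually picked up. A minimal sketch, assuming you pass in the platform names reported by jax.devices() (each JAX device has a .platform attribute such as "cpu" or "gpu"); the helper name is_using_gpu is hypothetical:

```python
from typing import Iterable


def is_using_gpu(platforms: Iterable[str]) -> bool:
    """Return True if any of the given device platforms is 'gpu'."""
    return any(p == "gpu" for p in platforms)


# Example usage (requires JAX to be installed):
#   import jax
#   print(jax.default_backend())
#   print(is_using_gpu(d.platform for d in jax.devices()))
```

If this prints False on a machine with a GPU, the installed jaxlib is most likely a CPU-only build, which would explain the slow training.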

If I replace PRNGKeyArray with key, I get a different error at a later stage:

Traceback (most recent call last):
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 196, in _run_module_as_main
   return _run_code(code, main_globals, None,
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 86, in _run_code
   exec(code, run_globals)
 File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 61, in <module>
   app.run(main=main)
 File "/home/user/Akash/Owl/scenic/scenic/app.py", line 68, in run
   app.run(functools.partial(_run_main, main=main))
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 308, in run
   _run_main(main, args)
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
   sys.exit(main(argv))
 File "/home/user/Akash/Owl/scenic/scenic/app.py", line 109, in _run_main
   main(rng=rng, config=config, workdir=workdir, writer=writer)
 File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 51, in main
   trainer.train(
 File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/trainer.py", line 218, in train
   gflops) = train_utils.initialize_model(
 File "/home/user/Akash/Owl/scenic/scenic/train_lib/train_utils.py", line 187, in initialize_model
   flops = debug_utils.compute_flops(
 File "/home/user/Akash/Owl/scenic/scenic/common_lib/debug_utils.py", line 139, in compute_flops
   flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
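The traceback ends in scenic/common_lib/debug_utils.py, where compute_flops indexes analysis['flops'] while analysis is None. As a local stopgap (not the fix this thread eventually settled on), that lookup could be made defensive. A minimal sketch: extract_flops is a hypothetical helper, and the assumption that a cost analysis may come back as None, a dict, or a list of dicts is ours, not documented scenic behavior. This only avoids the crash; the FLOP count for the model would simply be missing.

```python
from collections.abc import Mapping
from typing import Any, Optional


def extract_flops(analysis: Any) -> Optional[float]:
    """Return the 'flops' entry of a cost analysis, or None if unavailable."""
    if analysis is None:
        # No cost analysis was produced (the case in the traceback above).
        return None
    if isinstance(analysis, Mapping):
        return analysis.get("flops")
    # Assumption: some setups may return a list with one dict per computation.
    if len(analysis) == 0:
        return None
    return analysis[0].get("flops")
```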

Can anyone suggest what I can do here? Thank you.

UPDATE: I installed all libraries with the proper versions and got training working on the GPU with jax==0.4.23, but I am still getting the error mentioned above:

    flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

I have the same issue.

AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray' can be solved by replacing jax.random.PRNGKeyArray with jax.Array.
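If the same checkout has to run on both older and newer JAX releases, a small shim can pick whichever name exists instead of editing every annotation by hand. A minimal sketch, assuming the attribute is only used as a type annotation and that jax.Array exists in your JAX version; the helper name prng_key_type is hypothetical:

```python
from typing import Any


def prng_key_type(jax_module: Any) -> Any:
    """Return the type to use for PRNG key annotations.

    Prefers jax.random.PRNGKeyArray when the installed JAX still provides it,
    and falls back to jax.Array otherwise.
    """
    return getattr(jax_module.random, "PRNGKeyArray", jax_module.Array)


# Example usage (requires JAX to be installed):
#   import jax
#   PRNGKeyArray = prng_key_type(jax)
#   def init(rng: PRNGKeyArray): ...
```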

But that does not solve this error:

flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

This should be fixed in ott-jax==0.3.1.

I am already using ott-jax==0.3.1 and still get the same error.

Sorry, you should first run "pip install ott-jax==0.4.5"; if you then get an error about "transport", run "pip install ott-jax==0.3.1".

Yes, I am getting the "transport" error; that's why I am using ott-jax==0.3.1. And that version leads to this error:

flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

Running "pip install ott-jax==0.2.0" works.
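Since the outcome here came down to which ott-jax release sits alongside jax, it can help to print the versions actually installed in the active environment when reporting what works. A minimal sketch using the standard-library importlib.metadata; the helper name installed_versions is hypothetical and the distribution list in the usage comment is only an example:

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(dists: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Example usage:
#   print(installed_versions(["jax", "jaxlib", "ott-jax", "flax"]))
```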