AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'
Aki1991 opened this issue · 9 comments
Hi all, I am trying to fine-tune our model using the owl_vit model.
But when I try to run it, I get this error:
AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'
The JAX version I am using is 0.4.30. With jax 0.4.23 it works, but then training does not use the GPU, which slows it down a lot. Is there a way to keep jax 0.4.30 and fix this error?
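When one JAX version "works but does not use the GPU", it helps to confirm which backend JAX actually selected. Below is a minimal sketch using `jax.default_backend()` and `jax.devices()`; `describe_backend` is a hypothetical helper name. It only reports what the installed jaxlib can see and does not change anything.

```python
import jax


def describe_backend() -> str:
    """Summarize the JAX version, the default backend, and visible devices."""
    backend = jax.default_backend()  # e.g. "cpu" or "gpu"
    devices = jax.devices()          # devices on the default backend
    names = ", ".join(str(d) for d in devices)
    return f"jax {jax.__version__}: backend={backend}, devices=[{names}]"


print(describe_backend())
# If this reports backend=cpu on a machine with a GPU, one common cause is a
# CPU-only jaxlib build rather than a CUDA-enabled one.
```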
If I replace PRNGKeyArray with key, I get a different error at a later stage:
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 61, in <module>
    app.run(main=main)
  File "/home/user/Akash/Owl/scenic/scenic/app.py", line 68, in run
    app.run(functools.partial(_run_main, main=main))
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/Akash/Owl/scenic/scenic/app.py", line 109, in _run_main
    main(rng=rng, config=config, workdir=workdir, writer=writer)
  File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 51, in main
    trainer.train(
  File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/trainer.py", line 218, in train
    gflops) = train_utils.initialize_model(
  File "/home/user/Akash/Owl/scenic/scenic/train_lib/train_utils.py", line 187, in initialize_model
    flops = debug_utils.compute_flops(
  File "/home/user/Akash/Owl/scenic/scenic/common_lib/debug_utils.py", line 139, in compute_flops
    flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
Can anyone suggest what I can do here? Thank you.
UPDATE: I installed all libraries with compatible versions and got training running on the GPU with jax==0.4.23, but I am still getting the error mentioned above:
flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
Same issue here.
AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'
can be solved by replacing jax.random.PRNGKeyArray with jax.Array.
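A minimal sketch of that replacement, written so the same annotation works whether or not the installed JAX still has the old alias. `PRNGKey` and `sample_noise` are hypothetical names for illustration only; the actual annotations to change live in the scenic/owl_vit sources, which this thread does not list.

```python
import jax

# Older JAX releases expose jax.random.PRNGKeyArray; newer ones do not.
# Keys created by jax.random.PRNGKey are jax.Array instances, so fall back
# to jax.Array when the old alias is missing.
PRNGKey = getattr(jax.random, "PRNGKeyArray", jax.Array)


def sample_noise(rng: PRNGKey, shape: tuple) -> jax.Array:
    # Hypothetical helper: the annotation is only a type hint and does not
    # change runtime behavior.
    return jax.random.normal(rng, shape)


key = jax.random.PRNGKey(0)
noise = sample_noise(key, (2, 3))
print(noise.shape)  # (2, 3)
```

Because the annotation is never enforced at runtime, this change alone only removes the `AttributeError`; it does not touch the FLOP-counting code discussed next.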
But that does not solve:
flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
This should be fixed in ott-jax==0.3.1
I am already using ott-jax==0.3.1 and still get the same error.
Sorry, you should first run "pip install ott-jax==0.4.5"; if you then get an error about "transport", run "pip install ott-jax==0.3.1".
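Since each ott-jax pin can pull in a different jax/jaxlib, it is worth checking what actually ended up installed after every `pip install`. A small sketch using only the standard library's `importlib.metadata`; the package list is an assumption, so adjust it to your environment.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "ott-jax", "flax")):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```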
Yes, I am getting the "transport" error; that's why I am using ott-jax==0.3.1. That leads to the error:
flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
pip install ott-jax==0.2.0 works
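If pinning ott-jax is not an option, a defensive workaround for the `analysis['flops']` TypeError is to treat a missing cost analysis as "FLOPs unknown" rather than indexing into None. This is a sketch, not the scenic fix: `flops_or_none` is a hypothetical helper you would call where `compute_flops` currently does `analysis['flops']`, and it assumes the rest of your training code tolerates a None value for the `gflops` returned by `initialize_model` in the traceback above. It also accepts a list of per-computation dicts, since the shape of JAX's cost-analysis result may differ between releases.

```python
from typing import Any, Optional


def flops_or_none(analysis: Any) -> Optional[float]:
    """Return the 'flops' entry of a cost analysis, or None if unavailable.

    Hypothetical helper: accepts None, a dict, or a list/tuple of dicts.
    """
    if analysis is None:
        # The case behind "'NoneType' object is not subscriptable".
        return None
    if isinstance(analysis, (list, tuple)):
        if not analysis or analysis[0] is None:
            return None
        analysis = analysis[0]  # use the first computation's estimates
    flops = analysis.get("flops")
    return float(flops) if flops is not None else None


print(flops_or_none(None))              # None
print(flops_or_none({"flops": 2.0e9}))  # 2000000000.0
print(flops_or_none([{"flops": 5.0}]))  # 5.0
```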