Couldn't initialize models in demo notebooks
Closed this issue · 3 comments
Hi Berthy,
I really enjoyed your paper and would love to try out your model, but I got the following errors while trying the 3 notebooks and was wondering if you could help. Thank you very much!
My setup
- WSL2 Ubuntu 22.04.2 LTS
- python 3.10
- tensorflow 2.12
- jax/jaxlib 0.4.23 (the latest version before `jax.random.PRNGKeyArray` was removed)
Note about installing through conda.sh
- Not sure if it's a WSL2 behavior, but when I run `conda.sh`, the shell can't activate the `score_prior` conda environment and complains: `CondaError: Run 'conda init' before 'conda activate'`. I ended up splitting the script into 2 parts so I could activate the env manually before installing the rest.
- The line `pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` would install the latest jax (0.4.25, released Feb 26, 2024), which has already removed `jax.random.PRNGKeyArray`. That type is used in `posterior_sampling/realnvp_model.py` and `posterior_sampling/model_utils.py`, so I had to pin an earlier version to keep the original package unchanged.
- The line `pip install tensorflow-probability` also seems to install the newest `tensorflow-probability` (0.23.0), which is not compatible with the specified `tensorflow` 2.12. I had to pin the version explicitly: `pip install tensorflow-probability==0.20`.
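The pins described above can be checked from Python before opening the notebooks. This is a minimal sketch, not part of the repository: `PINS`, `installed_version`, and `find_mismatches` are hypothetical names, and the pinned versions are simply the ones listed in this comment.

```python
from importlib import metadata

# Pins taken from the setup described above; adjust them to your environment.
PINS = {
    "jax": "0.4.23",
    "tensorflow": "2.12",
    "tensorflow-probability": "0.20",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, get_version=installed_version):
    """Map each package that does not satisfy its pin to the version found."""
    mismatches = {}
    for package, pin in pins.items():
        found = get_version(package)
        # Match on dotted components, so pin "2.12" accepts "2.12.1" but not "2.120".
        if found is None or not (found == pin or found.startswith(pin + ".")):
            mismatches[package] = found
    return mismatches


if __name__ == "__main__":
    for package, found in find_mismatches(PINS).items():
        print(f"{package}: expected {PINS[package]}, found {found}")
```

An empty result means every pinned package is installed at a matching version; a value of `None` means the package is not installed at all.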
Errors when executing the notebooks
- 2d_posterior_sampling.ipynb
ValueError Traceback (most recent call last)
Cell In[6], line 25
23 z = jax.random.normal(init_rng, (batch_size, dim))
24 variables = model.init(init_rng, z, reverse=True)
---> 25 init_model_state, init_params = variables.pop('params')
27 # Initialize optimizer.
28 optimizer = optax.adam(learning_rate=learning_rate)
ValueError: too many values to unpack (expected 2)
- denoising.ipynb
ValueError Traceback (most recent call last)
Cell In[4], line 10
7 sde, t0 = utils.get_sde(score_config)
9 # Initialize score model.
---> 10 score_model, _, _ = score_mutils.init_model(jax.random.PRNGKey(0), score_config)
11 score_state = score_mutils.State(
12 step=0,
13 model_state=None,
(...)
17 params_ema=None,
18 rng=jax.random.PRNGKey(0))
20 # Load score-model checkpoint.
File ~/score_prior/demos/../score_flow/models/utils.py:142, in init_model(rng, config, data, label)
140 variables = model.init({'params': params_rng, 'dropout': dropout_rng}, init_input, init_label)
141 # Variables is a `flax.FrozenDict`. It is immutable and respects functional programming
--> 142 init_model_state, initial_params = variables.pop('params')
143 return model, init_model_state, initial_params
ValueError: too many values to unpack (expected 2)
- interferometry.ipynb
ValueError Traceback (most recent call last)
Cell In[6], line 10
7 sde, t0 = utils.get_sde(score_config)
9 # Initialize score model.
---> 10 score_model, _, _ = score_mutils.init_model(jax.random.PRNGKey(0), score_config)
11 score_state = score_mutils.State(
12 step=0,
13 model_state=None,
(...)
17 params_ema=None,
18 rng=jax.random.PRNGKey(0))
20 # Load score-model checkpoint.
File ~/score_prior/demos/../score_flow/models/utils.py:142, in init_model(rng, config, data, label)
140 variables = model.init({'params': params_rng, 'dropout': dropout_rng}, init_input, init_label)
141 # Variables is a `flax.FrozenDict`. It is immutable and respects functional programming
--> 142 init_model_state, initial_params = variables.pop('params')
143 return model, init_model_state, initial_params
ValueError: too many values to unpack (expected 2)
Thank you for bringing these to my attention! I created a branch that should solve things. Would you mind checking out the branch and confirming if the installation flows better and the model initialization errors go away?
Installing through conda.sh
- Moved the first few lines creating/activating the environment into the README and kept the rest in `conda.sh`.
- Specified `pip install tensorflow-probability==0.20` as you suggested.
- Changed all instances of `jax.random.PRNGKeyArray` to `jax.Array`, so you can use the latest 0.4.25 version of JAX.
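For code that still has to run on older JAX releases, the type rename can also be handled with a fallback lookup instead of a hard replacement. This is a minimal sketch under assumptions: `resolve_key_array_type` is a hypothetical helper that is not in the repository, and the two `SimpleNamespace` objects are stand-ins for an old and a new `jax` module so the lookup can be shown without importing JAX.

```python
from types import SimpleNamespace


def resolve_key_array_type(jax_module):
    """Return `jax_module.random.PRNGKeyArray` when present, else `jax_module.Array`.

    Hypothetical helper for illustration; pass the imported `jax` module.
    Assumes `jax_module.Array` exists, since it is read eagerly as the fallback.
    """
    return getattr(jax_module.random, "PRNGKeyArray", jax_module.Array)


# Stand-ins for an old and a new JAX release, used only to show the lookup.
old_jax = SimpleNamespace(
    random=SimpleNamespace(PRNGKeyArray="KeyType"), Array="ArrayType"
)
new_jax = SimpleNamespace(random=SimpleNamespace(), Array="ArrayType")

print(resolve_key_array_type(old_jax))
print(resolve_key_array_type(new_jax))
```

With the real library, the call would be `resolve_key_array_type(jax)` after `import jax`, and the result would be used in type annotations in place of `jax.random.PRNGKeyArray`.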
Model initialization errors
- Changed `variables.pop('params')` to `flax.core.pop(variables, 'params')` (I'm guessing `variables.pop` behavior changed in recent versions of Flax).
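One likely explanation for the `too many values to unpack (expected 2)` error, consistent with the guess above: if `model.init` returns a plain `dict` rather than a `FrozenDict`, then `variables.pop('params')` returns only the popped value instead of a `(remaining, value)` pair. The sketch below uses only plain dicts. `pop_functional` is a hypothetical stand-in that mimics the `(remaining, value)` return shape, not Flax's actual implementation, and the variable names are made up.

```python
def pop_functional(variables, key):
    """Return (remaining, value) without mutating `variables`.

    Hypothetical stand-in for illustration only.
    """
    remaining = {k: v for k, v in variables.items() if k != key}
    return remaining, variables[key]


def make_variables():
    # Toy variable collections; the names are made up for this example.
    return {
        "params": {"dense_0": 1.0, "dense_1": 2.0, "dense_2": 3.0},
        "batch_stats": {"mean": 0.0},
    }


# dict.pop returns just the value, so tuple-unpacking iterates over its
# three keys and fails in the same way as the notebook tracebacks.
try:
    init_model_state, init_params = make_variables().pop("params")
except ValueError as err:
    error_message = str(err)
    print(error_message)

# The functional form returns a pair, so the unpacking works.
variables = make_variables()
init_model_state, init_params = pop_functional(variables, "params")
print(sorted(init_model_state))
print(sorted(init_params))
```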
Now my environment uses JAX v0.4.25 and Flax v0.8.1, whereas before it used JAX v0.4.11 and Flax v0.6.10.
Hi Berthy,
Thanks for the reply! I tested the branch and it worked out of the box! I'm leaving my setup here as a reference.
My setup
- WSL2 Ubuntu 22.04.2 LTS
- python 3.10.13
- cudatoolkit 11.8.0
- diffrax 0.5.0
- flax 0.8.1
- jax/jaxlib 0.4.25
- numpy 1.23.5
- optax 0.1.9
- tensorflow 2.12.1
- tensorflow-probability 0.20.0
Small notes for the demo notebook
- I know you mentioned it in the README, but it would be more straightforward to have a reminder in `denoising.ipynb` and `interferometry.ipynb` for people to download the checkpoints from Box. Otherwise, initialization reads a non-existent checkpoint path and only throws a `'NoneType' object is not a mapping` error, which is a bit confusing.
- I wasn't aware of the resource requirements until `denoising.ipynb` threw an OOM error. I have an NVIDIA Quadro P5000 GPU with 16GB RAM that I thought was moderate, but I wonder if it makes sense to have a smaller DPI model as a demo so that people with fewer resources could at least test it out.
- The DPI optimization part in `denoising.ipynb` says it only took 0.2 sec / step when you executed it, and I was shocked. On my machine it took 10 min for the first step and then dropped to 5 sec / step. Does it just get faster with more steps (I gave up after 2000 steps), are you using some powerful hardware, or is some of my software acceleration like XLA not working properly (I did get some `xla/slow_operation_alarm` warnings in the denoising notebook)?
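The confusing `'NoneType' object is not a mapping` error from the first note could be avoided by checking for the checkpoint before loading it. This is a minimal sketch, assuming a hypothetical helper `require_checkpoint` and a placeholder directory name; neither comes from the repository.

```python
from pathlib import Path


def require_checkpoint(ckpt_dir):
    """Return `ckpt_dir` as a Path, or raise a clear error if it is missing or empty."""
    path = Path(ckpt_dir)
    if not path.is_dir() or not any(path.iterdir()):
        raise FileNotFoundError(
            f"No checkpoint found in '{path}'. Download the checkpoints "
            "from Box as described in the README before running this notebook."
        )
    return path


# Example call; "checkpoints/score_model" is a placeholder path.
try:
    require_checkpoint("checkpoints/score_model")
except FileNotFoundError as err:
    print(err)
```

Calling such a check at the top of a notebook cell turns a silent `None` into an error message that points at the missing download.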
Thank you for trying it out and confirming!
Good points! I made a commit that notes the Box link in `interferometry.ipynb` and `denoising.ipynb` and notes the computational cost in `denoising.ipynb`.
I recommend starting with `interferometry.ipynb` (or setting `config.optim.prior = 'dsm'` in `denoising.ipynb`), as that uses the much faster ELBO surrogate. The first step of DPI optimization always takes longer because it involves JIT compilation. Subsequent steps after JIT compilation should be relatively faster. For `denoising.ipynb`, I also found that the first step took maybe 10 minutes, and the subsequent steps took about 1 min/step. It's very possible that time would go up for different hardware.
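To compare numbers across machines, it helps to report the first step (which includes compilation) separately from the later steps. This is a minimal sketch with a hypothetical helper `time_steps`; the toy step only imitates a slow first call with `time.sleep` and is not the DPI training step.

```python
import time


def time_steps(step_fn, num_steps, clock=time.perf_counter):
    """Run `step_fn` `num_steps` times; return (first_step_s, mean_later_s).

    `mean_later_s` is None when only one step is run.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    durations = []
    for _ in range(num_steps):
        start = clock()
        # For JAX, `step_fn` should wait for its result (for example with
        # `jax.block_until_ready`); otherwise asynchronous dispatch can make
        # steps look faster than they are.
        step_fn()
        durations.append(clock() - start)
    later = durations[1:]
    mean_later = sum(later) / len(later) if later else None
    return durations[0], mean_later


# Toy step that is slow only on its first call, standing in for compilation.
calls = {"count": 0}


def toy_step():
    calls["count"] += 1
    time.sleep(0.05 if calls["count"] == 1 else 0.001)


first, later = time_steps(toy_step, num_steps=5)
print(f"first step: {first:.3f}s, later steps: {later:.3f}s on average")
```

Reporting the two numbers separately makes it easier to tell a one-time compilation cost apart from a slow steady state.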