berthyf96/score_prior

Couldn't initialize models in demo notebooks

Closed this issue · 3 comments

Hi Berthy,

I really enjoyed your paper and would love to try out your model, but I got the following errors while trying the 3 notebooks and was wondering if you could help. Thank you very much!

My setup

  • WSL2 Ubuntu 22.04.2 LTS
  • python 3.10
  • tensorflow 2.12
  • jax/jaxlib 0.4.23 (the last release before jax.random.PRNGKeyArray was removed)

Note about installing through conda.sh

  • Not sure if it's WSL2-specific, but when I ran conda.sh, the shell couldn't activate the score_prior conda environment and complained: CondaError: Run 'conda init' before 'conda activate'. I ended up splitting the script into 2 parts so I could activate the environment manually before installing the rest.
  • The line pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html installs the latest jax (0.4.25, released Feb 26, 2024), which no longer has jax.random.PRNGKeyArray, used in posterior_sampling/realnvp_model.py and posterior_sampling/model_utils.py. I had to pin an older version so the repo code would work unchanged.
  • The line pip install tensorflow-probability also seems to install the newest tensorflow-probability (0.23.0), which is not compatible with the specified tensorflow 2.12. I had to pin the version explicitly: pip install tensorflow-probability==0.20

Errors when executing the notebooks

  1. 2d_posterior_sampling.ipynb
ValueError                                Traceback (most recent call last)
Cell In[6], line 25
     23 z = jax.random.normal(init_rng, (batch_size, dim))
     24 variables = model.init(init_rng, z, reverse=True)
---> 25 init_model_state, init_params = variables.pop('params')
     27 # Initialize optimizer.
     28 optimizer = optax.adam(learning_rate=learning_rate)

ValueError: too many values to unpack (expected 2)
  2. denoising.ipynb
ValueError                                Traceback (most recent call last)
Cell In[4], line 10
      7 sde, t0 = utils.get_sde(score_config)
      9 # Initialize score model.
---> 10 score_model, _, _ = score_mutils.init_model(jax.random.PRNGKey(0), score_config)
     11 score_state = score_mutils.State(
     12     step=0,
     13     model_state=None,
   (...)
     17     params_ema=None,
     18     rng=jax.random.PRNGKey(0))
     20 # Load score-model checkpoint.


File ~/score_prior/demos/../score_flow/models/utils.py:142, in init_model(rng, config, data, label)
    140 variables = model.init({'params': params_rng, 'dropout': dropout_rng}, init_input, init_label)
    141 # Variables is a `flax.FrozenDict`. It is immutable and respects functional programming
--> 142 init_model_state, initial_params = variables.pop('params')
    143 return model, init_model_state, initial_params

ValueError: too many values to unpack (expected 2)
  3. interferometry.ipynb
ValueError                                Traceback (most recent call last)
Cell In[6], line 10
      7 sde, t0 = utils.get_sde(score_config)
      9 # Initialize score model.
---> 10 score_model, _, _ = score_mutils.init_model(jax.random.PRNGKey(0), score_config)
     11 score_state = score_mutils.State(
     12     step=0,
     13     model_state=None,
   (...)
     17     params_ema=None,
     18     rng=jax.random.PRNGKey(0))
     20 # Load score-model checkpoint.

File ~/score_prior/demos/../score_flow/models/utils.py:142, in init_model(rng, config, data, label)
    140 variables = model.init({'params': params_rng, 'dropout': dropout_rng}, init_input, init_label)
    141 # Variables is a `flax.FrozenDict`. It is immutable and respects functional programming
--> 142 init_model_state, initial_params = variables.pop('params')
    143 return model, init_model_state, initial_params

ValueError: too many values to unpack (expected 2)

Thank you for bringing these to my attention! I created a branch that should fix these issues. Would you mind checking out the branch and confirming whether the installation goes more smoothly and the model initialization errors go away?

Installing through conda.sh

  • Moved the first few lines creating/activating the environment into the README and kept the rest in conda.sh.
  • Specified pip install tensorflow-probability==0.20 as you suggested.
  • Changed all instances of jax.random.PRNGKeyArray to jax.Array, so you can use the latest 0.4.25 version of JAX.
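
For anyone patching an older checkout by hand, the annotation change looks roughly like this. A minimal sketch, assuming JAX >= 0.4.25, where jax.Array is the recommended annotation for PRNG keys; `sample_latents` is a hypothetical helper that mirrors the latent-sampling line in 2d_posterior_sampling.ipynb, not a function from the repo.

```python
import jax


def sample_latents(rng: jax.Array, batch_size: int, dim: int) -> jax.Array:
    """Draw standard-normal latents for a flow model (hypothetical helper).

    The key used to be annotated as jax.random.PRNGKeyArray, which is no
    longer available in JAX 0.4.25; jax.Array covers PRNG keys as well.
    """
    return jax.random.normal(rng, (batch_size, dim))


z = sample_latents(jax.random.PRNGKey(0), batch_size=4, dim=2)
print(z.shape)  # (4, 2)
```

Only the annotation changes; the runtime behavior of jax.random.normal is the same.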

Model initialization errors

  • Changed variables.pop('params') to flax.core.pop(variables, 'params') (I'm guessing variables.pop behavior changed in recent versions of Flax).
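
The error message itself points at the cause. Assuming, as guessed above, that newer Flax returns a plain dict from model.init rather than a FrozenDict, dict.pop('params') returns only the params sub-dict, and unpacking that dict into two names iterates over its keys. The pure-Python sketch below reproduces this with a made-up `variables` tree; `pop_collection` is a hypothetical stand-in that mimics the (remaining, popped) return convention of flax.core.pop.

```python
# Made-up variables tree shaped like a Flax init() result:
# a 'params' collection plus a mutable 'batch_stats' collection.
variables = {
    "params": {"conv_0": {}, "conv_1": {}, "dense_0": {}},
    "batch_stats": {"conv_0": {}},
}


def reproduce_old_bug(tree):
    """Mimic `init_model_state, initial_params = variables.pop('params')` on a plain dict."""
    tree = dict(tree)  # shallow copy so the caller's dict is untouched
    try:
        # dict.pop returns only the value; unpacking it iterates over its 3 keys.
        init_model_state, initial_params = tree.pop("params")
    except ValueError as err:
        return str(err)
    return None


def pop_collection(tree, key):
    """Hypothetical stand-in for flax.core.pop: return (tree without key, tree[key])."""
    remaining = {k: v for k, v in tree.items() if k != key}
    return remaining, tree[key]


print(reproduce_old_bug(variables))  # a "too many values to unpack" message
init_model_state, initial_params = pop_collection(variables, "params")
print(sorted(init_model_state), sorted(initial_params))
# ['batch_stats'] ['conv_0', 'conv_1', 'dense_0']
```

In the repo the fix is the one-liner init_model_state, initial_params = flax.core.pop(variables, 'params'), which handles both a plain dict and a FrozenDict.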

Now my environment uses JAX v0.4.25 and Flax v0.8.1, whereas before it used JAX v0.4.11 and Flax v0.6.10.

Hi Berthy,

Thanks for the reply! I tested the branch and it worked out of the box! I'm leaving my setup here as a reference.

My setup

  • WSL2 Ubuntu 22.04.2 LTS
  • python 3.10.13
  • cudatoolkit 11.8.0
  • diffrax 0.5.0
  • flax 0.8.1
  • jax/jaxlib 0.4.25
  • numpy 1.23.5
  • optax 0.1.9
  • tensorflow 2.12.1
  • tensorflow-probability 0.20.0
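
To compare another machine against this known-good setup, a small sketch like the following may help. It assumes the pip distribution names below; `KNOWN_GOOD` and `report_versions` are hypothetical names, and cudatoolkit is left out because it is a conda package rather than a pip distribution.

```python
import sys
from importlib import metadata

# Versions reported as working in this thread (WSL2, Ubuntu 22.04.2, CUDA 11.8).
KNOWN_GOOD = {
    "diffrax": "0.5.0",
    "flax": "0.8.1",
    "jax": "0.4.25",
    "jaxlib": "0.4.25",
    "numpy": "1.23.5",
    "optax": "0.1.9",
    "tensorflow": "2.12.1",
    "tensorflow-probability": "0.20.0",
}


def report_versions(expected=KNOWN_GOOD):
    """Return {package: (installed_or_None, expected)} for each pinned package."""
    report = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None  # not installed in this environment
        report[name] = (have, want)
    return report


if __name__ == "__main__":
    print("python", ".".join(map(str, sys.version_info[:3])), "(expected 3.10.13)")
    for name, (have, want) in report_versions().items():
        # CUDA jaxlib wheels may carry a local suffix such as "+cuda11...", so ignore it.
        ok = have is not None and have.split("+")[0] == want
        print(f"{name:24s} installed={have} expected={want} {'ok' if ok else 'MISMATCH'}")
```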

Small notes for the demo notebook

  1. I know you mentioned it in the README, but it would be more straightforward to have a reminder in denoising.ipynb and interferometry.ipynb to download the checkpoints from Box. Otherwise, initializing from a non-existent checkpoint path only throws a 'NoneType' object is not a mapping error, which is a bit confusing.
  2. I wasn't aware of the resource requirements until denoising.ipynb threw an OOM error. I have an NVIDIA Quadro P5000 GPU with 16 GB of memory, which I thought was moderate, but I wonder if it makes sense to have a smaller DPI model as a demo so that people with fewer resources could at least test it out.
  3. The DPI optimization part of denoising.ipynb says it took only 0.2 sec / step when you ran it, and I was shocked. On my machine the first step took 10 min, and after that it dropped to 5 sec / step. Does it just get faster with more steps (I gave up after 2000 steps), are you using much more powerful hardware, or is some of my software acceleration, like XLA, not working properly (I did get some xla/slow_operation_alarm warnings in the denoising notebook)?
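
On point 1, a guard that fails early with a readable message is one option. A minimal sketch using only the standard library; `require_checkpoint_dir` is a hypothetical helper, not part of the repo, and it only checks that the directory exists and is non-empty, not that it holds a valid checkpoint.

```python
from pathlib import Path


def require_checkpoint_dir(ckpt_dir: str) -> Path:
    """Raise a clear error if the score-model checkpoint directory is missing or empty.

    Hypothetical helper: meant to run before the notebook restores the
    checkpoint, so a missing download is reported directly instead of
    surfacing later as "'NoneType' object is not a mapping".
    """
    path = Path(ckpt_dir).expanduser()
    if not path.is_dir() or not any(path.iterdir()):
        raise FileNotFoundError(
            f"No checkpoint found at {path}. Download the checkpoints from the "
            "Box link in the README and point the notebook at that directory."
        )
    return path
```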

Thank you for trying it out and confirming!

Good points! I made a commit that notes the Box link in interferometry.ipynb and denoising.ipynb and notes the computational cost in denoising.ipynb.

I recommend starting with interferometry.ipynb (or setting config.optim.prior = 'dsm' in denoising.ipynb), as that uses the much faster ELBO surrogate. The first step of DPI optimization always takes longer because it includes JIT compilation; once compiled, subsequent steps should be considerably faster. For denoising.ipynb, I also found that the first step took maybe 10 minutes, and subsequent steps took about 1 min/step. It's quite possible these times are longer on other hardware.
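
The first-step cost can be seen in isolation with a toy function: the first call to a jax.jit-wrapped function traces and compiles it for the given input shapes, and later calls with the same shapes reuse the cached executable. A minimal sketch; `toy_step` is a hypothetical stand-in, not the DPI update, and block_until_ready() is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    # Stand-in for a training step: a few dense ops on a square matrix.
    for _ in range(5):
        x = jnp.tanh(x @ x.T) / x.shape[0]
    return x


def time_calls(fn, x, n_calls=3):
    """Wall-clock time of each call; the first includes tracing and XLA compilation."""
    times = []
    for _ in range(n_calls):
        start = time.perf_counter()
        fn(x).block_until_ready()  # wait for the asynchronous computation to finish
        times.append(time.perf_counter() - start)
    return times


x = jnp.ones((256, 256))
timings = time_calls(toy_step, x)
print([f"{t * 1e3:.1f} ms" for t in timings])
```

Changing the input shape or dtype triggers a fresh compilation, so keeping batch shapes fixed across steps avoids paying this cost again.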