Cannot run demo, possible incompatibility with latest JAX
daniel-trejobanos opened this issue · 4 comments
Dear all,
I am trying to run the demo examples, but I run into the following error:
ImportError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
2 import objax
3 import numpy as np
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
----> 1 from . import (
2 kernels,
3 utils,
4 ops,
5 likelihoods,
6 models,
7 basemodels,
8 inference,
9 cubature
10 )
13 def build_model(model, inf, name='GPModel'):
14 return type(name, (inf, model), {})
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in <module>
3 import jax.numpy as np
4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
----> 5 from jax.ops import index_add, index
6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
7 from warnings import warn
ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)
I think it's related to this note from the JAX changelog:
The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
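For reference, the migration itself looks mechanical; here is a minimal sketch of the old pattern next to its replacement, based on the changelog note above (illustrative values only):

```python
import jax.numpy as jnp

# Old API, removed after its deprecation in JAX 0.2.22:
#   from jax.ops import index_add, index
#   x = index_add(x, index[2:4], 1.0)

# New equivalent: JAX arrays are immutable, so .at[...] returns an updated copy.
x = jnp.zeros(6)
x = x.at[2:4].add(1.0)  # adds 1.0 to elements 2 and 3
print(x)                # [0. 0. 1. 1. 0. 0.]
```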
I now realise that your pip installation asks for a specific JAX version, which is a bit problematic for me: I am running on an M1 and installed JAX via conda-forge, so I am not sure I can match a compatible version. I will try and let you know if I succeed.
I managed to downgrade JAX, but there is no jaxlib 0.1.60 available on conda-forge, which seems like it could be the source of this bug I get when trying to import objax 1.3.1:
TypeError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
2 import objax
3 import numpy as np
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
----> 1 from . import (
2 kernels,
3 utils,
4 ops,
5 likelihoods,
6 models,
7 basemodels,
8 inference,
9 cubature
10 )
13 def build_model(model, inf, name='GPModel'):
14 return type(name, (inf, model), {})
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:1, in <module>
----> 1 import objax
2 from jax import vmap
3 import jax.numpy as np
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/__init__.py:17, in <module>
1 # Copyright 2020 Google LLC
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
15 import sys
---> 17 from ._patch_jax import *
19 pass # To avoid reordering imports from above
21 from . import functional
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/_patch_jax.py:20, in <module>
16 __all__ = []
18 from typing import Union, Sequence, Tuple, Callable, Optional
---> 20 import jax.numpy as jn
22 from .typing import JaxArray
23 from .util import re_sign
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/__init__.py:93, in <module>
89 from .version import __version__
91 # These submodules are separate because they are in an import cycle with
92 # jax and rely on the names imported above.
---> 93 from . import image
94 from . import lax
95 from . import nn
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/image/__init__.py:18, in <module>
15 """Common functions for neural network libraries."""
17 # flake8: noqa: F401
---> 18 from jax._src.image.scale import (
19 resize,
20 ResizeMethod,
21 scale_and_translate,
22 )
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/image/scale.py:20, in <module>
17 from typing import Callable, Sequence, Union
19 from jax import jit
---> 20 from jax import lax
21 from jax import numpy as jnp
22 import numpy as np
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/lax/__init__.py:324, in <module>
291 from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
292 _reduce_and, _reduce_window_sum, _reduce_window_max,
293 _reduce_window_min, _reduce_window_prod,
(...)
298 _upcast_fp16_for_computation, _broadcasting_shape_rule,
299 _eye, _tri, _delta, _ones, _zeros, _dilate_shape)
300 from jax._src.lax.control_flow import (
301 associative_scan,
302 cond,
(...)
322 while_p,
323 )
--> 324 from jax._src.lax.fft import (
325 fft,
326 fft_p,
327 )
328 from jax._src.lax.parallel import (
329 all_gather,
330 all_to_all,
(...)
346 xeinsum,
347 )
348 from jax._src.lax.other import (
349 conv_general_dilated_patches
350 )
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/lax/fft.py:87, in <module>
83 n = fft_lengths[-1]
84 return y[..., : n//2 + 1]
86 @partial(jit, static_argnums=1)
---> 87 def _rfft_transpose(t, fft_lengths):
88 # The transpose of RFFT can't be expressed only in terms of irfft. Instead of
89 # manually building up larger twiddle matrices (which would increase the
90 # asymptotic complexity and is also rather complicated), we rely JAX to
91 # transpose a naive RFFT implementation.
92 dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
93 dummy_primal = ShapeDtypeStruct(dummy_shape, _real_dtype(t.dtype))
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:184, in jit(fun, static_argnums, device, backend, donate_argnums)
129 """Sets up fun
for just-in-time compilation with XLA.
130
131 Args:
(...)
181 -0.85743 -0.78232 0.76827 0.59566 ]
182 """
183 if FLAGS.experimental_cpp_jit and config.omnistaging_enabled:
--> 184 return _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
185 else:
186 return _python_jit(fun, static_argnums, device, backend, donate_argnums)
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:370, in _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
367 return config.read("jax_disable_jit")
369 static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
--> 370 cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
371 get_jax_enable_x64, get_jax_disable_jit_flag,
372 static_argnums_)
374 # TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
375 @wraps(fun)
376 @api_boundary
377 def f_jitted(*args, **kwargs):
TypeError: jit(): incompatible function arguments. The following argument types are supported:
1. (fun: function, cache_miss: function, get_device: function, static_argnums: List[int], static_argnames: List[str] = [], donate_argnums: List[int] = [], cache: jaxlib.xla_extension.CompiledFunctionCache = None) -> object
Invoked with: <function _rfft_transpose at 0x7f93913f20d0>, <function _cpp_jit.<locals>.cache_miss at 0x7f93913f2160>, <function _cpp_jit.<locals>.get_device_info at 0x7f93913f21f0>, <function _cpp_jit.<locals>.get_jax_enable_x64 at 0x7f93913f2280>, <function _cpp_jit.<locals>.get_jax_disable_jit_flag at 0x7f93913f2310>, (0, 2)
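In case it is useful to anyone else debugging this, here is the small sanity check I would run first; my assumption (not verified) is that this TypeError comes from a jax/jaxlib version mismatch rather than from objax itself:

```python
# Print the installed jax and jaxlib versions; a mismatched pair is a common
# cause of low-level errors like the jit() signature mismatch above.
# (jaxlib.version.__version__ is the attribute I know of; newer jaxlib also
# exposes jaxlib.__version__ directly.)
import jax
import jaxlib

print("jax:   ", jax.__version__)
print("jaxlib:", jaxlib.version.__version__)
```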
Hi,
Sorry for the slow response and apologies that you've been having issues with the package versions. This is indeed frustrating. I would love to update the package to work with the most recent versions, but I don't currently have the spare time.
I am using an M1 Mac and things are working OK for me, but I'm not using conda-forge.
The index_update issue should be fairly easy to fix. However, I recall seeing some performance issues when I tried updating objax in the past, and I never managed to debug them. I hope to get around to this at some point in the future.
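If anyone needs to unblock the import before a proper fix lands, a shim along these lines near the top of kernels.py might work. This is an untested sketch: it assumes index_add is only ever used for additive updates, and the _Index helper is hypothetical rather than anything shipped in the package:

```python
# Hypothetical stand-ins for the removed jax.ops helpers, so existing call
# sites like index_add(A, index[i, j], v) keep working on recent JAX.
def index_add(x, idx, values):
    # JAX arrays are immutable; .at[idx].add(...) returns an updated copy.
    return x.at[idx].add(values)

class _Index:
    """Mimics jax.ops.index: index[2:4, 0] just hands back the key."""
    def __getitem__(self, idx):
        return idx

index = _Index()
```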