Test failure in `scico.jax` with latest `jax` version 0.4.29
bwohlberg opened this issue · 1 comment
bwohlberg commented
JAX release 0.4.29 appears to have broken a component of `scico.jax` again.
Full log:
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
rootdir: /home/runner/work/scico/scico
configfile: pytest.ini
testpaths: scico/test, docs
plugins: split-0.8.2
collected 3329 items / 3 skipped
scico/test/flax/test_apply.py ....... [ 0%]
scico/test/flax/test_checkpoints.py .... [ 0%]
scico/test/flax/test_clu.py ..... [ 0%]
scico/test/flax/test_examples_flax.py ss..ssssF......................... [ 1%]
..... [ 1%]
scico/test/flax/test_flax.py .......................... [ 2%]
[...]
=================================== FAILURES ===================================
__________________________ test_blur_data_generation ___________________________
> ???
_mt19937.pyx:180:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:766: in __index__
raise self.aval._index(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = ShapedArray(int64[])
arg = Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
val = Array([0], dtype=int64)
batch_dim = 0
def error(self, arg):
> raise TracerIntegerConversionError(arg)
E jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
E This BatchTracer with object id 140001038232176 was created on line:
E /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:1508: TracerIntegerConversionError
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "_mt19937.pyx", line 180, in numpy.random._mt19937.MT19937._legacy_seeding
File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 766, in __index__
raise self.aval._index(self)
File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 1508, in error
raise TracerIntegerConversionError(arg)
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
/home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
    def test_blur_data_generation():
        N = 32
        nimg = 8
        n = 3 # convolution kernel size
        blur_kernel = np.ones((n, n)) / (n * n)
        def random_img_gen(seed, size, ndata):
            np.random.seed(seed)
            return np.random.randn(ndata, size, size, 1)
>       img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen)
scico/test/flax/test_examples_flax.py:157:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scico/flax/examples/data_generation.py:318: in generate_blur_data
img = distributed_data_generation(imgfunc, size, nimg, False)
scico/flax/examples/data_generation.py:382: in distributed_data_generation
imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc)
scico/test/flax/test_examples_flax.py:154: in random_img_gen
np.random.seed(seed)
numpy/random/mtrand.pyx:4806: in numpy.random.mtrand.seed
???
numpy/random/mtrand.pyx:250: in numpy.random.mtrand.RandomState.seed
???
_mt19937.pyx:168: in numpy.random._mt19937.MT19937._legacy_seeding
???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> ???
E jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
E This BatchTracer with object id 140001038232176 was created on line:
E /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
_mt19937.pyx:185: TracerArrayConversionError
=========================== short test summary info ============================
FAILED scico/test/flax/test_examples_flax.py::test_blur_data_generation - jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
/home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
====== 1 failed, 3280 passed, 24 skipped, 27 xfailed in 294.89s (0:04:54) ======
Error: Process completed with exit code 1.
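The failure can probably be reproduced outside the test suite with a minimal sketch like the one below, assuming only `jax` and `numpy` are installed. `reproduce` is a hypothetical helper; the `jax.vmap(..., (0, None, None))` call mirrors the `jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc)` line from `distributed_data_generation` in the log. Judging from the traceback, NumPy's legacy seeding first calls `__index__` on the seed, which raises `TracerIntegerConversionError` (a `TypeError` subclass), and then falls back to array conversion via `__array__`, which raises `TracerArrayConversionError`. The exact final error may vary with the NumPy and JAX versions, so both are caught.

```python
import jax
import jax.numpy as jnp
import numpy as np


def random_img_gen(seed, size, ndata):
    # Same body as the test helper: NumPy's global RNG needs a concrete integer seed.
    np.random.seed(seed)
    return np.random.randn(ndata, size, size, 1)


def reproduce():
    # Hypothetical helper: one seed per batch element, mapped over axis 0.
    idx = jnp.arange(2)
    try:
        # Under vmap, `seed` arrives as a BatchTracer, which NumPy cannot
        # convert to a Python int or a NumPy array.
        jax.vmap(random_img_gen, (0, None, None))(idx, 4, 1)
    except (jax.errors.TracerIntegerConversionError,
            jax.errors.TracerArrayConversionError) as err:
        return type(err).__name__
    return None


print(reproduce())
```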
bwohlberg commented
The `scico/test/flax/test_examples_flax.py` tests are also failing on jax 0.4.28 (nominally supported according to the current `requirements.txt`) on GPU.
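Because the failing helper seeds NumPy's global RNG from a value that `jax.vmap` traces, one possible workaround is an `imgfunc` that draws its random data with `jax.random` instead. The sketch below is only an illustration, not the fix adopted in scico. It assumes `generate_blur_data` only requires `imgfunc(seed, size, ndata)` to return an array of shape `(ndata, size, size, 1)`, as the test helper does. Note that it produces a different random stream from `np.random.randn`.

```python
import jax
import jax.numpy as jnp


def random_img_gen(seed, size, ndata):
    # A PRNG key can be derived from a traced integer seed, so this is vmap-safe.
    key = jax.random.PRNGKey(seed)
    # size and ndata are unmapped (in_axes=None), so they stay concrete Python ints
    # and can be used to build a static shape.
    return jax.random.normal(key, (ndata, size, size, 1))


# Same mapping pattern as distributed_data_generation: seeds mapped, sizes not.
idx = jnp.arange(3)
imgs = jax.vmap(random_img_gen, (0, None, None))(idx, 8, 2)
print(imgs.shape)  # (3, 2, 8, 8, 1)
```

Each seed yields a reproducible, distinct batch of images, and the function also works when called directly with a Python integer seed.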