LinearSolveTest.test_solve_sparse fails with jax 0.4.26
GaetanLepage opened this issue · 1 comment
GaetanLepage commented
Context: updating jax in nixpkgs: NixOS/nixpkgs#291705 (comment)
One of the optax tests fails when run with the latest jax (0.4.26):
============================= test session starts ==============================
platform linux -- Python 3.11.9, pytest-8.1.1, pluggy-1.4.0
rootdir: /build/source
plugins: xdist-3.5.0
48 workers [561 items]
...s.................................................................... [ 12%]
........................................................................ [ 25%]
.............................................s......s............s...... [ 38%]
............s........s.................................................. [ 51%]
.................F...................................................... [ 64%]
........................................................................ [ 77%]
........................................................................ [ 89%]
......................................................... [100%]
=================================== FAILURES ===================================
______________________ LinearSolveTest.test_solve_sparse _______________________
[gw24] linux -- Python 3.11.9 /nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/bin/python3.11
self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse>
    def test_solve_sparse(self):
      rng = onp.random.RandomState(0)
      # Matrix case.
      A = rng.randn(5, 5)
      b = rng.randn(5)
      def matvec(x):
        return jnp.dot(A, x)
      x = linear_solve.solve_lu(matvec, b)
      x2 = linear_solve.solve_normal_cg(matvec, b)
      x3 = linear_solve.solve_gmres(matvec, b)
      x4 = linear_solve.solve_bicgstab(matvec, b)
      x5 = linear_solve.solve_iterative_refinement(matvec, b)
      x6 = linear_solve.solve_qr(matvec, b)
      self.assertArraysAllClose(x, x2, atol=1e-4)
      self.assertArraysAllClose(x, x3, atol=1e-4)
>     self.assertArraysAllClose(x, x4, atol=1e-4)
tests/linear_solve_test.py:133:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxopt/_src/test_util.py:292: in assertArraysAllClose
_assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x7ffcd857af20>, array([-6.9443436, -1.9871655, 7.7470713, 7.654949 ,...87526],
dtype=float32), array([-6.9444494, -1.9872105, 7.7471952, 7.655079 , -7.0388584],
dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=1e-06, atol=0.0001
E
E Mismatched elements: 2 / 5 (40%)
E Max absolute difference: 0.0001297
E Max relative difference: 2.267556e-05
E x: array([-6.944344, -1.987165, 7.747071, 7.654949, -7.038753],
E dtype=float32)
E y: array([-6.944449, -1.987211, 7.747195, 7.655079, -7.038858],
E dtype=float32)
/nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/lib/python3.11/contextlib.py:81: AssertionError
Any idea?
GaetanLepage commented
Wrong repo. This issue is actually happening in jaxopt. Sorry for the inconvenience.
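For anyone reading the log above: the sketch below re-checks the reported mismatch by hand. It assumes the element-wise criterion documented for `numpy.testing.assert_allclose`, namely `|actual - desired| <= atol + rtol * |desired|`, and uses the rounded `x` / `y` values printed in the assertion message rather than the exact float32 results, so the numbers are approximate. The variable names are illustrative only.

```python
# Values copied from the rounded x / y arrays printed in the assertion message.
x = [-6.944344, -1.987165, 7.747071, 7.654949, -7.038753]  # solve_lu result
y = [-6.944449, -1.987211, 7.747195, 7.655079, -7.038858]  # solve_bicgstab result

atol, rtol = 1e-4, 1e-6  # tolerances shown in the failure header

# Element-wise absolute differences and the allowed bound per element,
# assuming the criterion |actual - desired| <= atol + rtol * |desired|.
abs_diff = [abs(a - d) for a, d in zip(x, y)]
allowed = [atol + rtol * abs(d) for d in y]

# Indices where the difference exceeds the allowed bound.
mismatched = [i for i, (diff, bound) in enumerate(zip(abs_diff, allowed)) if diff > bound]

for i, (diff, bound) in enumerate(zip(abs_diff, allowed)):
    status = "FAIL" if i in mismatched else "ok"
    print(f"index {i}: diff={diff:.3e} allowed={bound:.3e} {status}")

print("mismatched indices:", mismatched)
```

Under these assumptions, only two of the five elements exceed the bound, and only by a few times 1e-5, which is consistent with the "Mismatched elements: 2 / 5" line in the log.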