Error in `tile_linalg_jacobi` for hydrogen
blazejba opened this issue · 3 comments
blazejba commented
Computing two hydrogen atoms with the cpu backend works:
python density_functional_theory.py -H -backend cpu -float32
but with the ipu backend
python density_functional_theory.py -H -backend ipu -float32
... an error is thrown:
Traceback (most recent call last):
  File "density_functional_theory.py", line 1440, in <module>
    elif args.H: recompute(args, None, 0, 0, our_fun=jax_dft, str=[["H", (0, 0, 0)],
  File "density_functional_theory.py", line 1095, in recompute
    energies, our_energy, our_hlgap, t_us, t_main_loop, us_hlgap = our_fun(str)
  File "density_functional_theory.py", line 1205, in jax_dft
    vals = density_functional_theory(atom_positions)
  File "density_functional_theory.py", line 665, in density_functional_theory
    vals = jax.jit(_do_compute, static_argnums=(10,11), device=device_1) ( density_matrix, kinetic, nuclear, overlap,
  File "density_functional_theory.py", line 148, in _do_compute
    vals = jax.lax.fori_loop(0, args.its, iter, vals)
  File "density_functional_theory.py", line 782, in iter
    eigvects = _eigh(generalized_hamiltonian )[1]
  File "density_functional_theory.py", line 1002, in _eigh
    eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 310, in ipu_eigh
    A, VT = ipu_jacobi_eigh(x, num_iters=num_iters)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 215, in ipu_jacobi_eigh
    Apcols, Aqcols, Vpcols, Vqcols = jax.lax.fori_loop(
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 213, in <lambda>
    eigh_iteration_fn = lambda _, x: ipu_jacobi_eigh_iteration(x, Atiles, Vtiles)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 180, in ipu_jacobi_eigh_iteration
    Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 276, in tile_rotate_columns
    assert len(pcols_indices_new) == halfN
AssertionError
AlexanderMath commented
`ipu_eigh` only works for d>=6. We fixed this in nanoDFT here. The same fix should also work in `density_functional_theory.py`.
The fix is just a switch:
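import jax.numpy as jnp

def eigh(x):
    d = x.shape[0]  # matrix dimension
    if d <= 6: return jnp.linalg.eigh(x)  # small matrices: use the stock solver
    else: return custom_ipu_eigh(x)  # placeholder name for the IPU solver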
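For anyone applying the same workaround elsewhere, here is a minimal, self-contained sketch of the switch. It is not the actual nanoDFT code. The name `eigh_with_fallback`, the `custom_eigh` parameter and the `MIN_CUSTOM_DIM` constant are made up for illustration, and the threshold of 6 is simply taken from the comment above. The sketch also assumes the custom solver returns `(eigenvalues, eigenvectors)` in the same order as `jnp.linalg.eigh`. The traceback suggests `ipu_eigh` returns them the other way round, so a real wrapper would probably need to swap them.

```python
import jax.numpy as jnp

# Assumption: threshold taken from the comment above; not verified against TessellateIPU.
MIN_CUSTOM_DIM = 6


def eigh_with_fallback(x, custom_eigh=None):
    """Eigendecomposition of a symmetric matrix with a small-matrix fallback.

    `custom_eigh` is a hypothetical stand-in for the IPU solver and is assumed
    to return (eigenvalues, eigenvectors), like jnp.linalg.eigh.
    """
    d = x.shape[-1]  # matrix dimension; a plain Python int, even under jax.jit
    if custom_eigh is None or d <= MIN_CUSTOM_DIM:
        # Small matrix (or no custom solver given): use the stock JAX solver.
        return jnp.linalg.eigh(x)
    return custom_eigh(x)


# Usage: a small symmetric matrix takes the fallback path.
x = jnp.array([[2.0, 1.0], [1.0, 2.0]])
eigvals, eigvects = eigh_with_fallback(x)
```

Because array shapes are static when JAX traces a function, the `if` on `d` is resolved at trace time, so the switch should also work inside `jax.jit` and `jax.lax.fori_loop` bodies.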
AlexanderMath commented
This PR fixes it. Feel free to accept/close this issue.
balancap commented
Thanks @blazejba for finding the issue! Just opened a bug ticket on TessellateIPU so we can solve it there: graphcore-research/tessellate-ipu#12
@AlexanderMath Thanks for the PR. Let's merge that for now until there is a fix in TessellateIPU.