graphcore-research/pyscf-ipu

Error in `tile_linalg_jacobi` for hydrogen

blazejba opened this issue · 3 comments

Computing two hydrogen atoms with the CPU backend works:
python density_functional_theory.py -H -backend cpu -float32

but with the IPU backend
python density_functional_theory.py -H -backend ipu -float32

the following error is thrown:

Traceback (most recent call last):
  File "density_functional_theory.py", line 1440, in <module>
    elif args.H: recompute(args, None, 0, 0, our_fun=jax_dft, str=[["H", (0, 0, 0)],
  File "density_functional_theory.py", line 1095, in recompute
    energies, our_energy, our_hlgap, t_us, t_main_loop, us_hlgap = our_fun(str)
  File "density_functional_theory.py", line 1205, in jax_dft
    vals = density_functional_theory(atom_positions)
  File "density_functional_theory.py", line 665, in density_functional_theory
    vals = jax.jit(_do_compute, static_argnums=(10,11), device=device_1) ( density_matrix, kinetic, nuclear, overlap,
  File "density_functional_theory.py", line 148, in _do_compute
    vals = jax.lax.fori_loop(0, args.its, iter, vals)
  File "density_functional_theory.py", line 782, in iter
    eigvects = _eigh(generalized_hamiltonian )[1]
  File "density_functional_theory.py", line 1002, in _eigh
    eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 310, in ipu_eigh
    A, VT = ipu_jacobi_eigh(x, num_iters=num_iters)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 215, in ipu_jacobi_eigh
    Apcols, Aqcols, Vpcols, Vqcols = jax.lax.fori_loop(
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 213, in <lambda>
    eigh_iteration_fn = lambda _, x: ipu_jacobi_eigh_iteration(x, Atiles, Vtiles)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 180, in ipu_jacobi_eigh_iteration
    Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 276, in tile_rotate_columns
    assert len(pcols_indices_new) == halfN
AssertionError

ipu_eigh only works for d >= 6; we fixed this in nanoDFT here. The same fix should also work in density_functional_theory.py.

The fix is just a switch:

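import jax.numpy as jnp
def eigh(x):
    d = x.shape[0]  # matrix dimension
    if d <= 6: return jnp.linalg.eigh(x)  # small matrices: ipu_eigh fails, use jnp.linalg.eigh
    else: return custom_ipu_eigh(x)       # larger matrices: use the IPU Jacobi solver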
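A slightly fuller, runnable sketch of the same switch in JAX, under stated assumptions: the names eigh_with_fallback, MIN_IPU_EIGH_DIM and ipu_eigh_fn are hypothetical, and the IPU solver is passed in as a function so the example also runs on CPU. Per the traceback, the script unpacks ipu_eigh as `eigvects, eigvals`, whereas jnp.linalg.eigh returns (eigenvalues, eigenvectors); the sketch assumes that order for the IPU path and swaps it, so both branches return the same layout.

```python
import jax
import jax.numpy as jnp

# Threshold from the issue: ipu_eigh is only reliable for d >= 6.
# Hypothetical name; like the switch in this thread, d <= 6 goes to the fallback.
MIN_IPU_EIGH_DIM = 6


def eigh_with_fallback(x, ipu_eigh_fn=None):
    """Return (eigenvalues, eigenvectors), ordered like jnp.linalg.eigh.

    ipu_eigh_fn is a hypothetical stand-in for the IPU solver. It is assumed
    to return (eigenvectors, eigenvalues), matching how the script unpacks
    ipu_eigh in the traceback.
    """
    d = x.shape[0]  # static under jax.jit, so this is a trace-time Python branch
    if ipu_eigh_fn is None or d <= MIN_IPU_EIGH_DIM:
        return jnp.linalg.eigh(x)  # small matrices: generic solver
    eigvects, eigvals = ipu_eigh_fn(x)  # large matrices: IPU solver
    return eigvals, eigvects  # normalize to jnp.linalg.eigh ordering


# Usage on a 2x2 matrix: the fallback path is taken.
h = jnp.array([[1.0, 0.5], [0.5, 1.0]], dtype=jnp.float32)
eigvals, eigvects = jax.jit(lambda m: eigh_with_fallback(m))(h)
```

Because x.shape is static under jax.jit, the branch is resolved at trace time, so only one of the two solvers ends up in the compiled program.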

This PR fixes it. Feel free to accept it and close this issue.

Thanks @blazejba for finding the issue! I just opened a bug ticket on TessellateIPU so we can fix it there: graphcore-research/tessellate-ipu#12
@AlexanderMath Thanks for the PR. Let's merge that for now until there is a fix in TessellateIPU.