`ipu_eigh` OOM for 500x500 matrix.
AlexanderMath opened this issue · 4 comments
Reproducer
```python
import jax
import jax.numpy as jnp
import numpy as np


def linalg_eigh(x):
    from tessellate_ipu.linalg import ipu_eigh

    eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
    return eigvals, eigvects


a = np.random.normal(0, 1, (500, 500))
print(jax.jit(linalg_eigh, backend="ipu")(a))
```
Found the issue. Take N=500 and look at lines 142-144:
```python
for _ in range(1, N):
    rotset_sorted = jacobi_sort_rotation_set(rotset)
    print(rotset.shape, rotset.nbytes)  # (250, 2) 2000
    rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
    ...
    # Next rotation set.
    rotset = jacobi_next_rotation_set(rotset)
```
Each `rotset` takes up 2 kB, and the loop creates about N=500 of them (one per iteration), i.e. 500 × 2 kB ≈ 1 MB. Since each one is replicated as a constant over `tiles=Atiles`, we try to put 1 MB on every one of those tiles => OOM.
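The arithmetic above can be checked in a few lines of Python. This is a sketch: it assumes the rotation sets are `int32` (consistent with the printed `(250, 2) 2000`) and takes roughly 624 KB of SRAM per tile, the figure for a Mk2 (GC200) IPU, as the per-tile budget. `TILE_SRAM_BYTES` is an illustrative constant, not something read from the library.

```python
import numpy as np

N = 500
TILE_SRAM_BYTES = 624 * 1024  # assumption: approx. per-tile SRAM on a Mk2 (GC200) IPU

# One rotation set holds N/2 index pairs; int32 matches the printed "(250, 2) 2000".
rotset = np.zeros((N // 2, 2), dtype=np.int32)
bytes_per_rotset = rotset.nbytes

# The loop `for _ in range(1, N)` builds one constant per iteration (N - 1 of them),
# and each constant is replicated on every tile in `Atiles`.
bytes_per_tile = (N - 1) * bytes_per_rotset

print(rotset.shape, bytes_per_rotset)  # (250, 2) 2000
print(bytes_per_tile, bytes_per_tile > TILE_SRAM_BYTES)  # 998000 True
```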
Potential fix: compute `rotset` on the fly instead of baking every one in as a constant. Nothing seems to require `rotset` to be known at compile time, so computing it on the fly looks fine.
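A minimal sketch of what "on the fly" could look like in plain JAX, assuming the rotation sets follow a round-robin (circle-method) pairing. The actual `jacobi_next_rotation_set` ordering may differ, and `pairs_from_perm`, `next_perm` and `count_pairs` are hypothetical helpers. Only one permutation of length N is carried through the loop, so memory no longer grows with the number of iterations; the `counts` matrix exists only to check that every index pair is visited exactly once.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pairs_from_perm(perm):
    """Hypothetical: rotation set of shape (N/2, 2), pairing position i with N-1-i."""
    half = perm.shape[0] // 2
    return jnp.stack([perm[:half], perm[::-1][:half]], axis=1)


def next_perm(perm):
    """Hypothetical circle-method step: keep perm[0] fixed, rotate the rest by one."""
    return jnp.concatenate([perm[:1], jnp.roll(perm[1:], 1)])


def count_pairs(n):
    """Run the N-1 rounds inside fori_loop, deriving each rotation set from the carry."""

    def body(_, carry):
        perm, counts = carry
        rotset = pairs_from_perm(perm)  # computed in the loop, not a baked-in constant
        counts = counts.at[rotset[:, 0], rotset[:, 1]].add(1)
        counts = counts.at[rotset[:, 1], rotset[:, 0]].add(1)
        return next_perm(perm), counts

    perm0 = jnp.arange(n, dtype=jnp.int32)
    counts0 = jnp.zeros((n, n), dtype=jnp.int32)
    _, counts = jax.lax.fori_loop(0, n - 1, body, (perm0, counts0))
    return counts


counts = np.asarray(jax.jit(count_pairs, static_argnums=0)(8))
# Every off-diagonal pair is rotated exactly once over the N-1 rounds.
print((counts == 1 - np.eye(8, dtype=np.int32)).all())  # True
```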
Found a fix; the profile for N=512 is attached below. I pushed the code to this branch. Let me know if it makes sense to open a PR.
Changes to `ipu_jacobi_eigh`:
- `rotset` was compiled to a constant, which added ~1 MB to certain tiles and caused the OOM (each tile held N=512 copies of `rotset`, for a total of 1472*512 ≈ 750k copies). It is now computed on the fly using `jax.numpy`, so each tile holds at most 2 `rotset`s at any time.
- Replaced `static_gather` with `all_cols[all_indices, :]`, which (my guess) gets compiled to `all_cols.T[:, all_indices].T`.
- Compile time blew up for N=512, so I replaced `for i in range(1, N)` with `jax.lax.fori_loop(1, N, iteration, ...)`.

Numerical error: for d=64, `np.max(vals)=1.14e-05` and `np.max(vects)=1.479`, but for N=512 it is 1e-4 even with `num_iters=64`. (I can't check whether the previous implementation behaves the same, because it OOMs at N=512.)
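The compile-time effect of switching to `jax.lax.fori_loop` can be seen without an IPU by comparing jaxpr sizes. This is a toy sketch: `step` is a hypothetical stand-in for one Jacobi iteration, not the actual update. The unrolled Python loop emits the body's operations once per iteration, whereas `fori_loop` traces the body once and emits a single loop primitive, so the program handed to the compiler does not grow with N.

```python
import jax
import jax.numpy as jnp


def step(i, x):
    # Hypothetical stand-in for one iteration of the Jacobi loop body.
    return jnp.roll(x, 1) + i


def unrolled(x, n):
    for i in range(1, n):  # body traced n-1 times
        x = step(i, x)
    return x


def rolled(x, n):
    return jax.lax.fori_loop(1, n, step, x)  # body traced once


x = jnp.arange(8, dtype=jnp.float32)
n = 32
eqns_unrolled = len(jax.make_jaxpr(lambda v: unrolled(v, n))(x).jaxpr.eqns)
eqns_rolled = len(jax.make_jaxpr(lambda v: rolled(v, n))(x).jaxpr.eqns)
print(eqns_unrolled, eqns_rolled)
print(jnp.allclose(unrolled(x, n), rolled(x, n)))  # True
```

Both versions compute the same result; only the size of the traced program differs.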
Potential tough questions before a PR:
- Do we want to retain control over unrolling vs `jax.lax.fori_loop`?
- Profiling with PopVision took ~10 min for N=512.
- Do we want the flexibility to switch between `np` and `jnp` when computing `rotset` (i.e. at trace time vs on the fly)? This could be done by adding an argument, e.g. `def ipu_eigh(..., np_backend=jax.numpy)`.
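A sketch of the `np_backend` idea, assuming the rotation-set helpers are written against a generic array module. The same round-robin (circle-method) code then runs either with `numpy` (evaluated in Python at trace time, so the results enter the program as constants, as in the old behaviour) or with `jax.numpy` (traced when called under `jax.jit`). `rotation_sets` and its circle-method ordering are hypothetical; only the `np_backend` name comes from the proposal above.

```python
import jax.numpy as jnp
import numpy as np


def rotation_sets(n, np_backend=jnp):
    """Hypothetical helper: all N-1 round-robin rotation sets, each of shape (N/2, 2)."""
    half = n // 2
    perm = np_backend.arange(n)
    sets = []
    for _ in range(n - 1):
        # Pair position i with position N-1-i.
        sets.append(np_backend.stack([perm[:half], perm[::-1][:half]], axis=1))
        # Circle method: keep perm[0] fixed, rotate the rest by one.
        perm = np_backend.concatenate([perm[:1], np_backend.roll(perm[1:], 1)])
    return sets


sets_np = rotation_sets(8, np_backend=np)    # NumPy: concrete values at trace time
sets_jnp = rotation_sets(8, np_backend=jnp)  # jax.numpy: traced when called under jax.jit
print(all(np.array_equal(a, np.asarray(b)) for a, b in zip(sets_np, sets_jnp)))  # True
```

Note that the NumPy path only works with a Python-level (unrolled) loop, because `jax.lax.fori_loop` passes traced values to its body; this choice is therefore tied to the unroll vs `fori_loop` question.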