graphcore-research/tessellate-ipu

`ipu_eigh` OOM for 500x500 matrix.

AlexanderMath opened this issue · 4 comments

Reproducer

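import jax
import jax.numpy as jnp
import numpy as np

def linalg_eigh(x):
  from tessellate_ipu.linalg import ipu_eigh
  # ipu_eigh returns (eigenvectors, eigenvalues); swap so the output order matches jnp.linalg.eigh.
  eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
  return eigvals, eigvects

# Random (non-symmetric) 500x500 input; enough to reproduce the OOM.
a = np.random.normal(0, 1, (500, 500))

print(jax.jit(linalg_eigh, backend="ipu")(a))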
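The reproducer passes a raw `np.random.normal` sample, which is not symmetric. That is enough to trigger the OOM, but `eigh` assumes a symmetric matrix, so accuracy numbers like the ones reported later in the thread need a symmetric input and a reference. A minimal CPU-side sketch using `np.linalg.eigh`; `eigh_residual` is a hypothetical helper, and it assumes eigenvectors are stored as columns (the `np.linalg.eigh` convention), which may not match the layout `ipu_eigh` returns.

```python
import numpy as np

def eigh_residual(x, vals, vects):
    """Max |x @ V - V @ diag(vals)| over all entries.

    Assumes eigenvectors are the columns of `vects` (np.linalg.eigh convention).
    The check does not depend on the sign chosen for each eigenvector.
    """
    return float(np.max(np.abs(x @ vects - vects * vals[None, :])))

rng = np.random.default_rng(0)
a = rng.normal(0, 1, (500, 500))
# eigh assumes a symmetric input, so symmetrize the random matrix first.
a = 0.5 * (a + a.T)

# CPU reference: eigenvalues in ascending order, eigenvectors as columns.
ref_vals, ref_vects = np.linalg.eigh(a)
print(eigh_residual(a, ref_vals, ref_vects))
```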

[PopVision profile screenshots]

Found the issue. For N=500, consider lines 142-144:

for _ in range(1, N):
    rotset_sorted = jacobi_sort_rotation_set(rotset)
    print(rotset.shape, rotset.nbytes) # (250, 2) 2000
    # Each iteration bakes a new rotset in as a constant replicated on every tile in Atiles.
    rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
    ...
    # Next rotation set.
    rotset = jacobi_next_rotation_set(rotset)

Each rotset takes 2 kB, and we create N=500 different ones: 500*2 kB = 1 MB. Since these are replicated as constants over tiles=Atiles, we try to put 1 MB on every one of those tiles, which is more than a tile's memory (624 KiB on a Mk2 IPU) and causes the OOM.
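The arithmetic above as a quick NumPy check. It assumes rotset holds int32 indices (consistent with the printed nbytes of 2000 for shape (250, 2)) and, following the profile, roughly N constants per tile; `TILE_SRAM_BYTES` assumes a Mk2 (GC200) IPU, and the helper names are hypothetical.

```python
import numpy as np

TILE_SRAM_BYTES = 624 * 1024  # assumption: Mk2 (GC200) IPU tile memory

def rotset_nbytes(n: int) -> int:
    # One rotation set pairs up all n indices: n // 2 pairs of int32 indices.
    return np.zeros((n // 2, 2), dtype=np.int32).nbytes

def constants_per_tile_bytes(n: int) -> int:
    # About N rotsets are baked in as constants, each replicated to every tile in Atiles.
    return n * rotset_nbytes(n)

for n in (100, 500, 512):
    print(n, rotset_nbytes(n), constants_per_tile_bytes(n),
          constants_per_tile_bytes(n) > TILE_SRAM_BYTES)
# 100 400 40000 False
# 500 2000 1000000 True
# 512 2048 1048576 True
```

Only the rotset constants are counted here; the rest of the program also needs tile memory, so staying under the limit does not by itself mean the graph fits.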

Potential fix: compute rotset on the fly instead of baking every rotation set into the graph as a constant; it looks like nothing requires rotset to be known at trace time.

Verified with the PopVision profile for N=100: on tile 64 there are 100 constants of 400 bytes each (the profile numbers them from zero, so index 99 means 100 constants).

Found a fix; the profile for N=512 is attached below. I pushed the code to this branch. Let me know if it makes sense to open a PR.

Changes, all in `ipu_jacobi_eigh`:

  1. rotset was compiled to a constant, which added 1 MB to certain tiles and caused the OOM (each tile held N=512 copies of rotset, so about 1472*512 ≈ 750k copies in total). rotset is now computed on the fly using jax.numpy, so each tile holds at most 2 rotsets at any time.
  2. `static_gather` is replaced with `all_cols[all_indices, :]`, which (my guess) gets compiled to `all_cols.T[:, all_indices].T`.
  3. Compile time blew up for N=512, so `for i in range(1, N)` is replaced with `jax.lax.fori_loop(1, N, iteration, ..)`.

Numerical error: for d=64, np.max(vals)=1.14e-05 and np.max(vects)=1.479, but the error is 1e-4 for N=512 even with num_iters=64. I can't check whether the previous implementation behaves the same at N=512, since it OOMs there.
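A minimal sketch of the two code changes, not the code from the branch. `pair_visit_counts` is a hypothetical stand-in that replaces the Jacobi rotations with bookkeeping, to show rotset carried through `jax.lax.fori_loop` (so it is a loop value computed on device rather than N compile-time constants). `next_rotation_set` is a hypothetical round-robin ("circle method") ordering and is not necessarily what `jacobi_next_rotation_set` produces. The last lines check the row-gather/transpose equivalence guessed at above.

```python
from functools import partial

import jax
import jax.numpy as jnp

def next_rotation_set(rotset):
    # Hypothetical round-robin update: rotset[0, 0] stays fixed, all other
    # indices shift one step around a cycle, so N - 1 rounds visit every pair once.
    top, bottom = rotset[:, 0], rotset[:, 1]
    new_top = jnp.concatenate([top[:1], bottom[:1], top[1:-1]])
    new_bottom = jnp.concatenate([bottom[1:], top[-1:]])
    return jnp.stack([new_top, new_bottom], axis=1)

@partial(jax.jit, static_argnums=0)
def pair_visit_counts(n: int):
    assert n >= 4 and n % 2 == 0  # the update above assumes at least 2 pairs
    rotset0 = jnp.arange(n, dtype=jnp.int32).reshape(n // 2, 2)

    def iteration(i, carry):
        rotset, counts = carry
        # Placeholder for the Jacobi rotations: record which (p, q) pairs are visited.
        counts = counts.at[rotset[:, 0], rotset[:, 1]].add(1)
        counts = counts.at[rotset[:, 1], rotset[:, 0]].add(1)
        return next_rotation_set(rotset), counts

    counts0 = jnp.zeros((n, n), dtype=jnp.int32)
    # A single traced loop body instead of N - 1 unrolled copies.
    _, counts = jax.lax.fori_loop(1, n, iteration, (rotset0, counts0))
    return counts

# Row gather via plain indexing vs. the transposed column gather.
all_cols = jnp.arange(24, dtype=jnp.float32).reshape(8, 3)
all_indices = jnp.array([3, 1, 0, 2], dtype=jnp.int32)
row_gather = all_cols[all_indices, :]
via_transpose = all_cols.T[:, all_indices].T
```

Every off-diagonal entry of `pair_visit_counts(n)` should be 1, which is the property a Jacobi sweep relies on.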

Potentially tough questions before opening a PR:

  1. Do we want to retain control over unrolling vs. jax.lax.fori_loop?
  2. Profiling with PopVision took ~10 min for N=512.
  3. Do we want the flexibility to switch between np and jnp when computing rotset (i.e. compute it at trace time or on the fly)? This could be done by adding an argument: def ipu_eigh(..., np_backend=jax.numpy).
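A sketch of what question 3 could look like. `np_backend` is the argument name proposed above; threading it through a `next_rotation_set` helper like this is hypothetical, as is the round-robin update itself. With `np`, the update runs eagerly in Python (trace time), so results used inside a jitted function become constants; with `jnp`, it also works on traced values and can live inside `jit` or `fori_loop`. Passing `np` inside `jit` would fail, because NumPy cannot consume JAX tracers.

```python
import jax
import jax.numpy as jnp
import numpy as np

def next_rotation_set(rotset, np_backend=jnp):
    """Hypothetical round-robin update written against a NumPy-like module."""
    top, bottom = rotset[:, 0], rotset[:, 1]
    new_top = np_backend.concatenate([top[:1], bottom[:1], top[1:-1]])
    new_bottom = np_backend.concatenate([bottom[1:], top[-1:]])
    return np_backend.stack([new_top, new_bottom], axis=1)

n = 8
rotset0 = np.arange(n, dtype=np.int32).reshape(n // 2, 2)

# np backend: evaluated eagerly, i.e. at trace time; a concrete numpy array.
host_rotset = next_rotation_set(rotset0, np_backend=np)

# jnp backend: traceable, so it can be computed on device inside jit / fori_loop.
device_rotset = jax.jit(lambda r: next_rotation_set(r, np_backend=jnp))(rotset0)
```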

[PopVision profile screenshot for N=512]

Closing, as PR #42 fixes the issue using a proper fori_loop.