Wrapping the ```init``` function inside ```jax.jit```
ksmdnl opened this issue · 1 comments
I'm currently doing a runtime analysis of the attention matrix of a transformer. Specifically, I'd like to know how the runtime scales with the size of the attention matrix.
import time

import haiku as hk
import jax
import numpy as np

def model(x):
    net = factory(hidden_dim, num_layers=1)
    return net(*x)

def inference(fn, rng, params, x, mode="runtime"):
    start = time.perf_counter()
    # Block on the result so asynchronous dispatch does not skew the timing.
    _ = jax.block_until_ready(fn(params, x))
    end = time.perf_counter()
    print(f"{mode}: {end - start} s")
    return end - start

def main():
    if args.single == 0:
        nb_nodes = np.arange(args.max_nb_node) + 1
    else:
        nb_nodes = [args.max_nb_node]
    runtimes = []
    rng = jax.random.PRNGKey(42)
    for nb_node in nb_nodes:
        print(f"Number of nodes: {nb_node}")
        net = hk.without_apply_rng(hk.transform(model))
        node_fts = jax.random.normal(rng, (batch_size, nb_node, hidden_dim))
        edge_fts = jax.random.normal(rng, (batch_size, nb_node, nb_node, hidden_dim))
        x = (node_fts, edge_fts)
        params = net.init(rng, x)
        apply_fn = jax.jit(net.apply)
        # compile time (the first call traces and compiles)
        _ = inference(apply_fn, rng, params, x, mode="compile time")
        # execution time (the second call reuses the compiled program)
        runtime = inference(apply_fn, rng, params, x, mode="execution time")
        runtimes.append(runtime)
For `nb_node = 414` I'm getting an OOM error during initialization (when performing the Einstein summation), which looks as follows:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 54101757696 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 249.86MiB
constant allocation: 0B
maybe_live_out allocation: 50.39GiB
preallocated temp allocation: 0B
total allocation: 50.63GiB
total fragmentation: 0B (0.00%)
This is quite strange, since I'm using an A100 80GB.
However, when I `jit` the `init` function, there is no OOM error even beyond `nb_node = 500`.
My question is: would this be a correct workaround in this case?
tl;dr - You should jit the `init` function to get a version of it that (1) uses as little memory as possible and (2) runs quickly.
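As a rough sketch of what that looks like: the snippet below uses a hypothetical pure-JAX `init_fn` as a stand-in for Haiku's `net.init` (the function, its einsum, and all shapes are made up for illustration), and compares calling it op-by-op with calling it through `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `net.init`: it creates parameters and runs one
# forward pass with an einsum. Names and shapes are assumptions, not your model.
def init_fn(rng, x):
    node_fts, edge_fts = x
    hidden_dim = node_fts.shape[-1]
    w = jax.random.normal(rng, (hidden_dim, hidden_dim)) * 0.02
    # Intermediate that is only needed while initializing.
    _ = jnp.einsum("bnmh,hk->bnmk", edge_fts, w)
    return {"w": w}

rng = jax.random.PRNGKey(42)
k_node, k_edge, k_init = jax.random.split(rng, 3)
x = (
    jax.random.normal(k_node, (2, 8, 16)),
    jax.random.normal(k_edge, (2, 8, 8, 16)),
)

# Op-by-op: every intermediate array is materialized on the device.
params_eager = init_fn(k_init, x)
# Compiled: XLA sees the whole program and may fuse or drop intermediates.
params_jit = jax.jit(init_fn)(k_init, x)

# With Haiku, the equivalent call would be: params = jax.jit(net.init)(rng, x)
```

Both calls return the same parameters; only the way the work is scheduled on the device differs.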
To add a bit more detail, `jax.jit` (through XLA) applies a number of optimizations to your program. Some of these might reduce the overall peak memory footprint required by the `init` program. For example, one optimization that XLA does is limiting the live range of arrays.
Let's consider the following JAX program:
def fn():
    a = some_big_array()
    b = other_big_array()
    c = a + b
    d = yet_another_big_array()
    e = c + d
    return e
XLA would be able to notice that `a` and `b` can be safely deleted (and as such their GPU memory would be freed) before you compute `d`:
def fn():
    a = some_big_array()
    b = other_big_array()
    c = a + b
    # XLA knows a/b aren't used again so it can release GPU memory for them.
    free_gpu_memory(a)
    free_gpu_memory(b)
    d = yet_another_big_array()
    e = c + d
    return e
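If you want to see what the compiler thinks the program needs, one option is to compile ahead of time and ask for its memory estimate. The sketch below is a runnable version of `fn` with small random arrays standing in for the big ones (the shapes and the `key` argument are assumptions for illustration). Note that `memory_analysis()` can return `None` on backends that don't report it.

```python
import jax
import jax.numpy as jnp

def fn(key):
    k_a, k_b, k_d = jax.random.split(key, 3)
    shape = (512, 512)  # small stand-in for the "big" arrays
    a = jax.random.normal(k_a, shape)
    b = jax.random.normal(k_b, shape)
    c = a + b
    d = jax.random.normal(k_d, shape)
    e = c + d
    return e

key = jax.random.PRNGKey(0)

# Lower and compile ahead of time so the compiled program can be inspected.
compiled = jax.jit(fn).lower(key).compile()

# Compile-time memory estimate; may be None if the backend doesn't provide one.
stats = compiled.memory_analysis()
if stats is not None:
    print(stats)

# The compiled program is called like the original function.
e = compiled(key)
```

Comparing this estimate across input sizes gives a rough idea of how the jitted program's footprint grows, without having to run into an OOM first.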
If you wanted to debug further and figure out which arrays were still hanging around causing the 50GB allocation to fail, then JAX allows you to see a traceback for where arrays were created, which might help you understand this in more depth. To help you get started, something like the following might work:
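import traceback

import jax

def print_live_arrays():
    # Print each live array along with the traceback of where it was created.
    for array in jax.live_arrays():
        print(array.shape, array.dtype)
        traceback.print_tb(array.traceback.as_python_traceback())
        print()

try:
    benchmark()
except RuntimeError as e:
    # The exception is not a string, so match against its message.
    if 'RESOURCE_EXHAUSTED' in str(e):
        print_live_arrays()
    raise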
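If the full tracebacks are too noisy, a smaller variation is to total up the live arrays and list only the largest ones. This is just a sketch; `live_array_summary` and its `top_k` parameter are names I made up for illustration.

```python
import jax
import jax.numpy as jnp

def live_array_summary(top_k=5):
    """Return total bytes held by live arrays and details of the largest ones."""
    arrays = jax.live_arrays()
    total_bytes = sum(a.nbytes for a in arrays)
    largest = sorted(arrays, key=lambda a: a.nbytes, reverse=True)[:top_k]
    return total_bytes, [(a.shape, a.dtype, a.nbytes) for a in largest]

# A 4 MiB float32 array that stays alive because `x` still references it.
x = jnp.ones((1024, 1024), dtype=jnp.float32)

total_bytes, largest = live_array_summary()
print(f"{total_bytes / 2**20:.1f} MiB held by live arrays")
for shape, dtype, nbytes in largest:
    print(shape, dtype, nbytes)
```

Calling this just before the failing `init` should show which arrays from earlier iterations are still being held.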
That said, even if you knew the root cause (e.g. which arrays were hanging around) the recommended fix is the same: use `jax.jit` and let XLA optimize this for you.