Automatically transform in-place array assignments using autograph
dime10 opened this issue · 4 comments
Context
Catalyst uses a source-to-source transformation package called AutoGraph, which allows users to write regular Python code that is automagically transformed into JAX-style traceable code. For example, the following for loop can automatically be compiled by Catalyst to execute at run-time using the autograph option, whereas we would otherwise need to explicitly use the functional form:
from catalyst import *

@qjit(autograph=True)
def f(n: int):
    for i in range(n):
        debug.print(i)

f(5)
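For comparison, the functional form we would otherwise need looks roughly like the following; the sketch below uses JAX's jax.lax.fori_loop purely for illustration (Catalyst's own catalyst.for_loop primitive follows a similar pattern):

import jax

def f(n: int):
    def body(i, carry):
        # print the loop index at run time; nothing is carried between iterations
        jax.debug.print("{}", i)
        return carry
    return jax.lax.fori_loop(0, n, body, 0)

f(5)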
We would like to extend the current AutoGraph support to apply to array updates as is typically done with NumPy arrays:
import numpy as np
x = np.zeros(10)
y = np.array([3, 2, 5])
# update some elements of x with y
x[4:7] = y
Attempting to do the same with JAX arrays will raise the following error:
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
Instead, JAX requires that a new array be created with the .at style notation (the result can optionally be reassigned to the original variable):
import jax.numpy as jnp
x = jnp.zeros(10)
y = jnp.array([3, 2, 5])
# update some elements of x with y
x = x.at[4:7].set(y)
Goal
We would like the array updates to be automatically converted to the required JAX style when the autograph flag is set to True. The following example demonstrates this feature:
@qjit(autograph=True)
def expand_by_two(x):
    """Expand the first dimension of x to twice the size and broadcast the contents."""
    first_dim = x.shape[0]
    result = jnp.empty((first_dim*2, *x.shape[1:]), dtype=x.dtype)
    for i in range(2):
        start = i * first_dim
        stop = start + first_dim
        result[start:stop] = x
    return result
>>> expand_by_two(jnp.array([5, 3, 4]))
Array([5, 3, 4, 5, 3, 4], dtype=int64)
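In other words, the converted code should behave like the version one can already write by hand with the .at notation; the hand-written equivalent below (the name expand_by_two_manual is just for illustration) is what the transformation should effectively produce:

import jax.numpy as jnp

def expand_by_two_manual(x):
    """Hand-written equivalent of the example above, using the .at notation."""
    first_dim = x.shape[0]
    result = jnp.empty((first_dim * 2, *x.shape[1:]), dtype=x.dtype)
    for i in range(2):
        start = i * first_dim
        stop = start + first_dim
        # what `result[start:stop] = x` should be converted into
        result = result.at[start:stop].set(x)
    return result

expand_by_two_manual(jnp.array([5, 3, 4]))  # Array([5, 3, 4, 5, 3, 4], ...)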
Requirements:
- the feature only needs to work with regular assignments; +=, *=, etc. are not required
- if the assigned-to object is not a JAX array, the implementation should fall back to the regular Python operator
- the AutoGraph guide should be updated
Technical details
- AutoGraph works by identifying specific patterns in the AST of a given Python function and replacing them with calls to pre-determined operators. By providing our own implementation of these operators we can change the behaviour of existing code, including that of Python built-ins.
- The main object that manages the AutoGraph transformations in Catalyst is the CFTransformer.
- Implementations for overloaded operators in Catalyst are provided in the ag_primitives.py module.
- AutoGraph has an overloadable set_item operator (example implementation from TensorFlow) which requires the List Feature. This operator should allow us to implement the desired functionality; a rough sketch of the idea is included below.
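To make the last point concrete, here is a rough standalone sketch, not Catalyst's actual ag_primitives code, of the dispatch such an overloaded operator would need, assuming the converter rewrites x[i] = y into roughly x = ag__.set_item(x, i, y); the exact type check (e.g. to also cover tracers) is part of the implementation work:

import jax
import jax.numpy as jnp

def set_item(target, i, x):
    """Sketch of an overloaded set_item operator."""
    if isinstance(target, jax.Array):
        # JAX arrays are immutable: build and return an updated copy
        return target.at[i].set(x)
    # fall back to the regular Python item assignment for everything else
    target[i] = x
    return target

data = jnp.zeros(5)
data = set_item(data, 2, 7)   # Array([0., 0., 7., 0., 0.], ...)

lst = [0, 0, 0, 0, 0]
lst = set_item(lst, 2, 7)     # [0, 0, 7, 0, 0]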
When working on AutoGraph features it is particularly important to pay attention to edge cases and include comprehensive test cases.
Note that this issue does not require installing Catalyst from source. The PyPI wheel can simply be downloaded and installed/extracted, and the Python files modified in-place. The optional dependency on tensorflow needs to be separately installed.
I've encountered some difficulty with implementing this feature as described in the issue.
I'm including all of the details of what I've attempted as well as resolutions to each problem I encountered (except for Problem 3). If you would like I can additionally commit and push my (messy) work to a fork; so far I have been editing the source files in the wheel as you suggested.
Test case used
I converted your example case into a test case, which I tested against in frontend/test/pytest/test_autograph.py.
def test_array_set(self):
    @qjit(autograph=True)
    def fn(x):
        """Expand the first dimension of x to twice the size and broadcast the contents."""
        first_dim = x.shape[0]
        result = jnp.empty((first_dim*2, *x.shape[1:]), dtype=x.dtype)
        for i in range(2):
            start = i * first_dim
            stop = start + first_dim
            result[start:stop] = x
        return result

    check_cache(fn.original_function)
    assert jnp.array_equal(fn(jnp.array([5, 3, 4])), jnp.array([5, 3, 4, 5, 3, 4]))
Problem 1: Wrong converter?
I stubbed out an implementation of set_item as suggested in the Technical details.
I noticed that this function seemed to not get invoked in my testing.
It seems like there was a typo in the description and the actual converter is the slices converter.
If this is not the case, then I think I am rather confused. Let me know what I can do to make my confusion clear (e.g. pushing my code to a fork).
Problem 2: get_item
If I use the slices converter, then it also converts
first_dim = x.shape[0]
into
first_dim = ag__.get_item(x.shape, 0, opts=ag__.GetItemOpts(element_dtype=None))
which is fine by me because I can make an implementation of get_item too.
However, I'm not too sure how to pass the opts around. It seems tensorflow uses it to propagate information about the type of the element (see this function), but I did not investigate it in careful detail because get_item on the whole seems secondary to this issue.
I'm inclined to simplify the test example you provided (when I use it for my own tests) to avoid having to deal with the opts argument correctly, but if you know of an obvious way to handle this I'm happy to implement get_item too.
Problem 3: set_item
Even when I simplify the code to something like
@qjit(autograph=True)
def fn(x, y):
    x[:] = y
    return x
I still see the error you describe:
TypeError: '<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
The (relevant) part of the translated code appears to still be
x[:] = y
I'm confused because I would expect this to have been translated to
ag__set_item(x, SomeIndex(), y)
since I am using
node = slices.transform(node, ctx)
to transform the AST (per Problem 1).
I additionally tried lists.transform just in case I missed something, but that didn't seem to do anything.
I also thought it was an issue with the ordering of transformations, so I tried putting the slices transformation at the very beginning in the CFTransformer, but that also didn't do anything.
I don't have a solution to this problem: do you have any ideas on why this translation might not be happening? I've tried manually tracing the tensorflow implementation of slices.transform to no avail.
I would prefer to have submitted a PR for this, but I've spent more than a couple of hours between tracing things and referring to documentation already. Were I being paid to work on this issue, I probably would have reached out a little earlier.
Hi @cole-k,
Great job digging into the problem! Regarding your questions:
Problem 1: Wrong converter?
You're right, the link given in the problem description points to the wrong transformer; the relevant one is the SliceTransformer. The reason I mentioned the "Slice Feature" is because that is what enables the SliceTransformer in the TensorFlow implementation of AutoGraph:
https://github.com/tensorflow/tensorflow/blob/6b3ea7502770f9921e6ddb0a7c6b1d9f292498f6/tensorflow/python/autograph/impl/api.py#L249-L251
This could have been a bit clearer, apologies for that!
Problem 2: get_item
Exactly, we can just provide a default Python implementation of get_item since we don't want to assign any special functionality to this operator. With that in mind, I don't think we care about the value of opts, since the original operands of the operator are just passed in as is (in your example we get x.shape and 0, which is sufficient to reconstruct x.shape[0]).
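For concreteness, a pass-through default along those lines could be as small as the sketch below; accepting opts as an ignored optional argument is an assumption about how the operator gets called:

def get_item(target, i, opts=None):
    # Default Python behaviour: plain indexing. `opts` is accepted only to
    # match AutoGraph's calling convention and is otherwise ignored.
    return target[i]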
Problem 3: set_item
Are you saying your implementation of ag__.set_item is not being invoked? Have you confirmed this with a debugger or some other means (e.g. printing)?
The line you are using to run the transformer seems correct, although without seeing the surrounding code I can't be 100% certain. How are you looking at the translated code?
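One low-tech way to check (a generic debugging suggestion rather than anything Catalyst-specific) is to make the stub impossible to miss:

def set_item(target, i, x):
    # temporary instrumentation while debugging: fail loudly if ever reached
    print(f"set_item called with {type(target).__name__} and index {i!r}")
    raise NotImplementedError("set_item was reached")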
Problem 3 explained
Problem 3: set_item
Are you saying your implementation of ag__.set_item is not being invoked? Have you confirmed this with a debugger or some other means (e.g. printing)?
The line you are using to run the transformer seems correct, although without seeing the surrounding code I can't be 100% certain. How are you looking at the translated code?
It seems like the issue stems from the fact that the subscripts we are using are themselves not translated.
I will refer to the source code of the slice translator in my explanation:
L35      s = target.slice
L36      if isinstance(s, (gast.Tuple, gast.Slice)):
             # if None is returned, the target is not translated
L37          return None
We'll examine what goes wrong with the line
result[start:stop] = x
Step 1: parsing
The assignment appears to be parsed from
result[start:stop] = x
to
Assign(
    [Subscript(Name('result', _ctx), Slice(Name('start'), Name('stop'), None), _ctx)],
    Name('x', _ctx),
    '',
)
Nothing wrong here.
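This is easy to double-check with the standard library's ast module (gast, which AutoGraph uses, mirrors these node types closely); the snippet below is only a sanity check of the parse shape:

import ast

tree = ast.parse("result[start:stop] = x")
assign = tree.body[0]
target = assign.targets[0]
print(type(target).__name__)        # Subscript
print(type(target.slice).__name__)  # Slice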
Step 2: autograph translation
The slice translator works on the Assign if there's a single target and that target is a Subscript. It looks at its slice and doesn't change anything if the slice is an instance of Slice or Tuple (this behavior is the same for the code that converts to get_item).
So the source line is not transformed, which is the root cause of my troubles.
Aside
This explains why I had the issue initially with get_item: the first line subscripts with a Constant (x.shape[0]), which is neither a Slice nor a Tuple. Other subscripts, like x.shape[1:], are left untouched.
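The difference is visible directly in the parse; the check below uses the stdlib ast module and Python 3.9+ node names, purely as an illustration:

import ast

def index_node(src):
    """Return the AST node type used as the subscript index in `src`."""
    expr = ast.parse(src, mode="eval").body
    return type(expr.slice).__name__

print(index_node("x.shape[0]"))          # Constant -> converted to get_item
print(index_node("x.shape[1:]"))         # Slice    -> left untouched
print(index_node("result[start:stop]"))  # Slice    -> left untouched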
Step 3: evaluation
So even though I provide an implementation of ag__set_item(...), there is no such occurrence in the translated code and it therefore errors in the way I'm describing.
A hacky patch
If we change the code to
result[jnp.s_[start:stop]] = x
i.e., use s_ (so that the index is a nested subscript expression at parse time, neither a Tuple nor a Slice), then we won't encounter this issue because the slice translator won't automatically reject it.
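For what it's worth, s_ just builds an ordinary slice object at run time, so the indexing semantics are unchanged; only the parse-time shape of the expression differs. The check below uses NumPy's np.s_ since I'm not certain jnp re-exports it:

import ast
import numpy as np

print(np.s_[4:7])   # slice(4, 7, None) -- same object a bare 4:7 would produce

# at parse time the index is now a nested Subscript, not a Slice,
# so the SliceTransformer's isinstance check no longer skips it
target = ast.parse("result[np.s_[start:stop]] = x").body[0].targets[0]
print(type(target.slice).__name__)   # Subscript   (Python 3.9+)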
While this isn't my favorite solution, it's the best thing I could come up with without subclassing the slice translator and explicitly changing how it translates (which I would rather not do given all the work I've put in so far).
Hmm, so the SliceTransformer doesn't support slicing?
A bit strange, but that's alright. We can just restrict a first implementation of this feature to only support single indices :)
Thanks for the investigation!