PennyLaneAI/catalyst

Automatically transform in-place array assignments using autograph

dime10 opened this issue · 4 comments

Context

Catalyst uses a source-to-source transformation package called AutoGraph, which allows users to write regular Python code that is automagically transformed into JAX-style traceable code. For example, the following for loop can automatically be compiled by Catalyst to execute at run-time using the autograph option, whereas we would otherwise need to explicitly use the functional form:

from catalyst import *

@qjit(autograph=True)
def f(n: int):
    for i in range(n):
        debug.print(i)

f(5)
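For comparison, here is a sketch of an explicit functional loop. It uses plain JAX's `jax.lax.fori_loop` as a stand-in for Catalyst's own functional control flow. The function `sum_below` is a hypothetical example and does not come from the issue:

```python
import jax


@jax.jit
def sum_below(n):
    # Explicit functional loop: the body receives the loop index and the
    # loop-carried value, and must return the updated carried value.
    def body(i, total):
        return total + i

    return jax.lax.fori_loop(0, n, body, 0)


print(sum_below(5))  # 0 + 1 + 2 + 3 + 4
```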

We would like to extend the current AutoGraph support to array updates, as is typically done with NumPy arrays:

import numpy as np

x = np.zeros(10)
y = np.array([3, 2, 5])

# update some elements of x with y
x[4:7] = y

Attempting to do the same with JAX arrays will raise the following error:

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
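A minimal reproduction of that error outside of any compilation. The exact class name in the message may differ between JAX versions:

```python
import jax.numpy as jnp

x = jnp.zeros(10)
y = jnp.array([3, 2, 5])

try:
    x[4:7] = y  # NumPy-style in-place update
    message = ""
except TypeError as err:
    message = str(err)

print(message)
```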

Instead, JAX requires creating a new array with the .at notation (the result can optionally be reassigned to the original variable):

import jax.numpy as jnp

x = jnp.zeros(10)
y = jnp.array([3, 2, 5])

# update some elements of x with y
x = x.at[4:7].set(y)
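To make the immutability explicit, a small sketch showing that `.at[...].set(...)` returns a new array and leaves the original untouched:

```python
import jax.numpy as jnp

x = jnp.zeros(10)
y = jnp.array([3, 2, 5])

# .at[...].set(...) builds a new array instead of mutating x
x_new = x.at[4:7].set(y)

print(x)      # unchanged: still all zeros
print(x_new)  # 3, 2, 5 written into positions 4, 5 and 6 (cast to x's float dtype)
```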

Goal

We would like the array updates to be automatically converted to the required JAX style when the autograph flag is set to True. The following example demonstrates this feature:

@qjit(autograph=True)
def expand_by_two(x):
    """Expand the first dimension of x to twice the size and broadcast the contents."""
    
    first_dim = x.shape[0]
    result = jnp.empty((first_dim*2, *x.shape[1:]), dtype=x.dtype)

    for i in range(2):
        start = i * first_dim
        stop = start + first_dim
        result[start:stop] = x

    return result

>>> expand_by_two(jnp.array([5, 3, 4]))
Array([5, 3, 4, 5, 3, 4], dtype=int64)
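For reference, a hand-written sketch of what the converted function might look like, using plain `jax.jit` rather than `qjit`. The name `expand_by_two_manual` is hypothetical. The sketch relies on `jax.jit` executing the Python `for` loop at trace time, so `start` and `stop` are static integers:

```python
import jax
import jax.numpy as jnp


@jax.jit
def expand_by_two_manual(x):
    """Hand-written .at[] version of expand_by_two."""
    first_dim = x.shape[0]  # shapes are static under jax.jit
    result = jnp.empty((first_dim * 2, *x.shape[1:]), dtype=x.dtype)

    # The Python loop is unrolled during tracing
    for i in range(2):
        start = i * first_dim
        stop = start + first_dim
        result = result.at[start:stop].set(x)  # functional replacement for result[start:stop] = x

    return result


print(expand_by_two_manual(jnp.array([5, 3, 4])))
```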

Requirements:

  • the feature only needs to work with regular assignments; augmented assignments such as +=, *=, etc. are not required
  • if the assigned-to object is not a JAX array, the implementation should fall back to the regular Python operator
  • the AutoGraph guide should be updated

Technical details

  • AutoGraph works by identifying specific patterns in the AST of a given Python function and replacing them with calls to predetermined operators. By providing our own implementations of these operators, we can change the behaviour of existing code, including that of Python built-ins.
  • The main object that manages the AutoGraph transformations in Catalyst is the CFTransformer.
  • Implementations for overloaded operators in Catalyst are provided in the ag_primitives.py module.
  • AutoGraph has an overloadable set_item operator (example implementation from TensorFlow) which requires the List Feature. This operator should allow us to implement the desired functionality.
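As a rough illustration of the operator these bullets describe, here is a hypothetical `set_item` overload that meets the fallback requirement above. It assumes the converter rewrites `target[i] = x` into `target = set_item(target, i, x)`, so the updated value must be returned. The signature is an assumption and is not Catalyst's actual `ag_primitives.py` code:

```python
import jax
import jax.numpy as jnp


def set_item(target, i, x):
    """Hypothetical set_item overload.

    Assumes AutoGraph rewrites `target[i] = x` into
    `target = set_item(target, i, x)`, so the result must be returned.
    """
    if isinstance(target, jax.Array):  # also true for tracers under jit
        # JAX arrays are immutable: build an updated copy
        return target.at[i].set(x)
    # Fallback to the regular Python operator (lists, dicts, NumPy arrays, ...)
    target[i] = x
    return target


@jax.jit
def update_first(x, y):
    # What a converted `x[0] = y` might look like
    x = set_item(x, 0, y)
    return x


print(update_first(jnp.zeros(3), 7.0))
print(set_item([0, 0, 0], 1, 5))  # [0, 5, 0]
```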

When working on AutoGraph features it is particularly important to pay attention to edge cases and include comprehensive test cases.

Note that this issue does not require installing Catalyst from source. The PyPI wheel can simply be downloaded and installed/extracted, and the Python files modified in-place. The optional dependency on TensorFlow needs to be installed separately.

I've encountered some difficulty with implementing this feature as described in the issue.

I'm including all the details of what I've attempted, as well as resolutions to each problem I encountered (except for Problem 3). If you would like, I can also commit and push my (messy) work to a fork; so far I have been editing the source files in the wheel as you suggested.

Test case used

I converted your example into a test case in frontend/test/pytest/test_autograph.py and tested against it.

    def test_array_set(self):
        @qjit(autograph=True)
        def fn(x):
            """Expand the first dimension of x to twice the size and broadcast the contents."""

            first_dim = x.shape[0]
            result = jnp.empty((first_dim*2, *x.shape[1:]), dtype=x.dtype)

            for i in range(2):
                start = i * first_dim
                stop = start + first_dim
                result[start:stop] = x

            return result

        check_cache(fn.original_function)
        assert jnp.array_equal(fn(jnp.array([5, 3, 4])), jnp.array([5, 3, 4, 5, 3, 4]))

Problem 1: Wrong converter?

I stubbed out an implementation of set_item as suggested in the Technical details.

I noticed that this function did not seem to be invoked in my testing.

It seems that the description links to the wrong converter, and the relevant one is actually the slices converter.

If this is not the case, then I am rather confused. Let me know what I can do to help clear this up (e.g. pushing my code to a fork).

Problem 2: get_item

If I use the slices converter, then it also converts

first_dim = x.shape[0]

into

first_dim = ag__.get_item(x.shape, 0, opts=ag__.GetItemOpts(element_dtype=None))

which is fine by me because I can make an implementation of get_item too.

However, I'm not too sure how to pass the opts around. It seems TensorFlow uses it to propagate information about the element type (see this function), but I did not investigate it in detail because get_item as a whole seems secondary to this issue.

I'm inclined to simplify the test example you provided (when I use it for my own tests) to avoid having to handle the opts argument, but if you know of an obvious way to handle it, I'm happy to implement get_item too.

Problem 3: set_item

Even when I simplify the code to something like

@qjit(autograph=True)
def fn(x, y):
  x[:] = y
  return x

I still see the error you describe:

TypeError: '<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

The (relevant) part of the translated code appears to still be

x[:] = y

I'm confused because I would expect this to have been translated to

ag__.set_item(x, SomeIndex(), y)

since I am using

node = slices.transform(node, ctx)

to transform the AST (per Problem 1).

I additionally tried lists.transform just in case I missed something, but that didn't seem to do anything.

I also thought it might be an issue with the ordering of transformations, so I tried putting the slices transformation at the very beginning of the CFTransformer, but that didn't change anything either.

I don't have a solution to this problem: do you have any ideas why this translation might not be happening? I've tried manually tracing the TensorFlow implementation of slices.transform to no avail.


I would have preferred to submit a PR for this, but I've already spent more than a couple of hours tracing code and reading documentation. Were I being paid to work on this issue, I probably would have reached out a little earlier.

Hi @cole-k,

Great job digging into the problem! Regarding your questions:

Problem 1: Wrong converter?

You're right, the link given in the problem description points to the wrong transformer; the relevant one is the SliceTransformer. The reason I mentioned the "Slice Feature" is that it is what enables the SliceTransformer in the TensorFlow implementation of AutoGraph:
https://github.com/tensorflow/tensorflow/blob/6b3ea7502770f9921e6ddb0a7c6b1d9f292498f6/tensorflow/python/autograph/impl/api.py#L249-L251
This could have been a bit clearer, apologies for that!

Problem 2: get_item

Exactly, we can just provide a default Python implementation of get_item, since we don't want to assign any special functionality to this operator. With that in mind, I don't think we care about the value of opts, since the original operands of the operator are passed in as-is (in your example we get x.shape and 0, which is sufficient to reconstruct x.shape[0]).
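A minimal sketch of such a default, assuming the call shape seen in the converted code earlier in the thread (`get_item(target, i, opts=...)`). The function below is hypothetical and simply ignores `opts`:

```python
import jax.numpy as jnp


def get_item(target, i, opts=None):
    """Hypothetical default get_item overload.

    `opts` (e.g. AutoGraph's GetItemOpts) is accepted but ignored, since
    `target` and `i` are enough to rebuild `target[i]`.
    """
    return target[i]


x = jnp.zeros((3, 2))
first_dim = get_item(x.shape, 0, opts=None)  # equivalent to x.shape[0]
print(first_dim)  # 3
```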

Problem 3: set_item

Are you saying your implementation of ag__.set_item is not being invoked? Have you confirmed this with a debugger or some other means (e.g. printing)?
The line you are using to run the transformer seems correct, although without seeing the surrounding code I can't be 100% certain. How are you looking at the translated code?

Problem 3 explained


It seems the issue is that an assignment is never translated at all when its subscript is a slice.

I will refer to the source code of the slice translator in my explanation:

s = target.slice
if isinstance(s, (gast.Tuple, gast.Slice)):
    # if None is returned, the target is not translated
    return None

We'll examine what goes wrong with the line

result[start:stop] = x

Step 1: parsing

The assignment appears to be parsed from

result[start:stop] = x

to

Assign(
  [Subscript(Name('result', _ctx), Slice(Name('start'), Name('stop'), None), _ctx)],
  Name('x', _ctx),
  '',
)

Nothing wrong here.

Step 2: autograph translation

The slice translator works on the Assign if there's a single target and that target is a Subscript. It looks at the target's slice and doesn't change anything if the slice is an instance of Slice or Tuple (the code that converts to get_item behaves the same way).
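The check can be mimicked with Python's standard `ast` module (in place of `gast`) to see which assignments would be rejected. This assumes Python 3.9+, where a subscript's index is stored directly in `Subscript.slice`. `would_convert` is a hypothetical helper, not part of AutoGraph:

```python
import ast


def would_convert(src: str) -> bool:
    """Mimic the slice converter's check on a single assignment statement."""
    node = ast.parse(src).body[0]
    if not isinstance(node, ast.Assign) or len(node.targets) != 1:
        return False
    target = node.targets[0]
    if not isinstance(target, ast.Subscript):
        return False
    # Slice and Tuple indices are rejected, i.e. left untranslated
    return not isinstance(target.slice, (ast.Tuple, ast.Slice))


for src in [
    "result[start:stop] = x",          # Slice index
    "result[0, 1] = x",                # Tuple index
    "result[i] = x",                   # Name index
    "result[jnp.s_[start:stop]] = x",  # Subscript index
]:
    print(src, "->", would_convert(src))
```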

So the source line is not transformed, which is the root cause of my troubles.

Aside

This also explains the get_item behavior I saw initially: the first line (x.shape[0]) subscripts with a Constant, which is neither a Slice nor a Tuple, so it gets converted. Other subscripts, like x.shape[1:], are left untouched.

Step 3: evaluation

So even though I provide an implementation of ag__.set_item(...), it never appears in the translated code, and the assignment fails with the error described above.

A hacky patch

If we change the code to

result[jnp.s_[start:stop]] = x

i.e., use s_ (the index is then itself a Subscript node, which is neither a Tuple nor a Slice at parse time), then we won't encounter this issue because the slice translator won't automatically reject it.
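A quick sketch of why this works: `jnp.s_` builds an ordinary `slice` object at run time, so the index expression is no longer a `Slice` node in the AST, while indexing with it behaves like the literal slice:

```python
import jax.numpy as jnp

start, stop = 4, 7

# jnp.s_ turns slice syntax into an ordinary slice object
idx = jnp.s_[start:stop]
print(idx)  # slice(4, 7, None)

x = jnp.zeros(10)
y = jnp.array([3, 2, 5])

# Indexing with the slice object matches the literal slice
same = bool((x.at[idx].set(y) == x.at[start:stop].set(y)).all())
print(same)  # True
```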

While this isn't my favorite solution, it's the best thing I could come up with without subclassing the slice translator and explicitly changing how it translates (which I would rather not do given all the work I've put in so far).
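For comparison, here is a rough sketch of that alternative: a transformer that also rewrites slice and tuple indices. It uses the standard `ast` module rather than AutoGraph's `gast`-based transformer classes, and it only handles plain names as assignment targets. `SetItemRewriter`, `rewrite`, and `set_item` are hypothetical names, and `ast.unparse` requires Python 3.9+:

```python
import ast


def _index_to_expr(index):
    """Convert a subscript index into an expression valid outside brackets."""
    if isinstance(index, ast.Slice):
        # `a:b:c` is only legal inside [...], so build slice(a, b, c) instead
        parts = [p if p is not None else ast.Constant(None)
                 for p in (index.lower, index.upper, index.step)]
        return ast.Call(func=ast.Name("slice", ast.Load()), args=parts, keywords=[])
    if isinstance(index, ast.Tuple):
        # Multi-dimensional index such as m[0, 1:3]
        return ast.Tuple(elts=[_index_to_expr(e) for e in index.elts], ctx=ast.Load())
    return index


class SetItemRewriter(ast.NodeTransformer):
    """Rewrite `name[index] = value` into `name = set_item(name, index, value)`."""

    def visit_Assign(self, node):
        if len(node.targets) != 1:
            return node
        target = node.targets[0]
        # Only handle plain names, e.g. `result[...] = x`, not `self.a[...] = x`
        if not (isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name)):
            return node
        name = target.value.id
        call = ast.Call(
            func=ast.Name("set_item", ast.Load()),
            args=[ast.Name(name, ast.Load()), _index_to_expr(target.slice), node.value],
            keywords=[],
        )
        new_node = ast.Assign(targets=[ast.Name(name, ast.Store())], value=call)
        return ast.copy_location(new_node, node)


def rewrite(src: str) -> ast.Module:
    return ast.fix_missing_locations(SetItemRewriter().visit(ast.parse(src)))


print(ast.unparse(rewrite("result[start:stop] = x")))
# result = set_item(result, slice(start, stop, None), x)
```

A bare `start:stop` is only valid inside brackets, which is why the sketch converts it into an explicit `slice(start, stop, None)` call before passing it to `set_item`.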

Hmm, so the SliceTransformer doesn't support slicing?

A bit strange, but that's alright. We can just restrict a first implementation of this feature to only support single indices :)

Thanks for the investigation!