patrick-kidger/jaxtyping

Is it possible for method annotation of an eqx.Module to pick up instance field annotations?

jjyyxx opened this issue · 7 comments

Currently, unless I am missing something, the instance field annotations and the instance method annotations of an eqx.Module are disjoint: dimension names bound while checking the fields are not visible when checking a method. This prevents the module's inputs from being fully checked (see option 1 in the snippet). In other words, field annotations only catch bugs during module initialization and do not help when a caller invokes the module. I can think of several workarounds (see options 2 and 3 in the snippet), but they come at the cost of increased complexity.

import equinox as eqx
import jax

from jaxtyping import Array, Float, Key, jaxtyped
from beartype import beartype

@jaxtyped(typechecker=beartype)
class Linear(eqx.Module):
    in_size: int = eqx.field(static=True)
    out_size: int = eqx.field(static=True)
    weight: Float[Array, "out in"]
    bias: Float[Array, "out"]

    def __init__(self, in_size: int, out_size: int, *, key: Key[Array, ""]):
        self.in_size = in_size
        self.out_size = out_size

        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    # 1: Ideal, but not working
    @jaxtyped(typechecker=beartype)
    def __call__(self, x: Float[Array, "in"]) -> Float[Array, "out"]:
        return self.weight @ x + self.bias

    # 2: Works, but verbose and inconsistent
    @jaxtyped(typechecker=beartype)
    def __call__(self, x: Float[Array, "{self.in_size}"]) -> Float[Array, "{self.out_size}"]:
        return self.weight @ x + self.bias

    # 3: Works, but duplicates weight and bias annotation
    def __call__(self, x: Float[Array, "in"]) -> Float[Array, "out"]:
        return self._impl(self.weight, self.bias, x)
    @staticmethod
    @jaxtyped(typechecker=beartype)
    def _impl(weight: Float[Array, "out in"], bias: Float[Array, "out"], x: Float[Array, "in"]) -> Float[Array, "out"]:
        return weight @ x + bias

model = Linear(2, 3, key=jax.random.key(0))
x = jax.numpy.zeros((2,))
model(x)

I went through the issues and did not find relevant discussions. What is the suggested way to address this? Can this be implemented in jaxtyping?

What do you mean by option 2 being 'inconsistent'? Agreed it's a little verbose but should accomplish what you're after.

I mean that in this case, the field annotations (e.g., in/out) and the method annotations (e.g., self.in_size/self.out_size) refer to the same concept but use very different naming styles.

class Linear(eqx.Module):
    weight: Float[Array, "out in"]
    bias: Float[Array, "out"]
    # ...
    @jaxtyped(typechecker=beartype)
    def __call__(self, x: Float[Array, "{self.in_size}"]) -> Float[Array, "{self.out_size}"]:
        return self.weight @ x + self.bias

Of course, I could change the weight and bias annotations to weight: Float[Array, "{self.out_size} {self.in_size}"] and bias: Float[Array, "{self.out_size}"], but that would be even more verbose.

Additionally, it can sometimes be challenging to directly associate the dimensions with hyperparameters, which may lead to something like this:

    def __call__(self, x: Float[Array, "{self.weight.shape[1]}"]) -> Float[Array, "{self.weight.shape[0]}"]:
        return self.weight @ x + self.bias

I appreciate the design of jaxtyping because it strikes a nice balance between static type checking, runtime type checking, and documentation readability. If annotations have to be written like Float[Array, "{self.weight.shape[1]}"] to allow for runtime type checking, it can significantly reduce readability.
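To make the readability trade-off concrete, here is a toy resolver that shows what `"{...}"`-style dimensions reduce to: each one is formatted against the call's bound arguments (here just `self`). It is a hypothetical helper built on `str.format`, which accepts attribute access and integer indexing inside braces. jaxtyping's own evaluation of these strings may work differently.

```python
from types import SimpleNamespace


def resolve_dims(dims: str, arguments: dict) -> tuple:
    """Toy resolver: format each "{...}" dimension against the bound arguments."""
    # str.format understands attribute access and integer indexing inside the
    # braces, so "{self.weight.shape[1]}" needs no eval().
    return tuple(int(d.format(**arguments)) for d in dims.split())


# Stand-in for the Linear module (not an actual eqx.Module).
model = SimpleNamespace(in_size=2, out_size=3, weight=SimpleNamespace(shape=(3, 2)))
arguments = {"self": model}

print(resolve_dims("{self.in_size}", arguments))                  # (2,)
print(resolve_dims("{self.weight.shape[1]}", arguments))          # (2,)
print(resolve_dims("{self.out_size} {self.in_size}", arguments))  # (3, 2)
```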

Ah, I see what you're getting at. Indeed, something like bias: Float[Array, "{self.out_size}"] would then be the consistent + fully-checked way to describe things.

Unfortunately I don't have a better solution to this problem; if you have a suggestion then I'm happy to hear it.

FWIW this doesn't worry me too much. There are many things we could, in principle, put into the type system -- e.g. is this integer positive? -- that we typically don't. Type systems are always about picking a trade-off between verbosity and validation, so I think we're in the usual situation in this regard. :)

@patrick-kidger I went through the jaxtyping internals and found that implementing the feature I requested wasn't too difficult. I've created a prototype and would appreciate your feedback on it. While it generally works, the design, interface, naming, internals, and edge case handling still need refinement.

Regarding the semantics, since a class’s suite (including attributes and parameter annotations in method declarations) shares the same dedicated local namespace, I believe it’s intuitive to extend these semantics to the dim_str of attributes and method parameters. I’m not suggesting a breaking change in behavior, but I think it’s a reasonable feature to support behind an option.
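The scoping rule this argument leans on can be checked in plain Python. Names used in a method's signature annotations are resolved in the class namespace, but the method body cannot see those names. jaxtyping dimension names are strings looked up in its own memo rather than Python names, so the proposal would extend an analogous rule to them. The class below is purely illustrative.

```python
import inspect


class Example:
    Dim = int  # a name bound in the class suite

    # The signature annotations are resolved in the class namespace, so the
    # bare name `Dim` works here ...
    def method(self, x: Dim) -> Dim:
        return x

    # ... but the class namespace is not an enclosing scope for method bodies.
    def body_lookup(self):
        try:
            return Dim  # noqa: F821
        except NameError:
            return None


print(inspect.signature(Example.method).parameters["x"].annotation)  # <class 'int'>
print(Example().body_lookup())  # None
```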

diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py
index 2d0eac7..1883b6c 100644
--- a/jaxtyping/_decorator.py
+++ b/jaxtyping/_decorator.py
@@ -30,7 +30,7 @@ from jaxtyping import AbstractArray
 
 from ._config import config
 from ._errors import AnnotationError, TypeCheckError
-from ._storage import pop_shape_memo, push_shape_memo, shape_str
+from ._storage import get_shape_memo, pop_shape_memo, push_shape_memo, shape_str
 
 
 class _Sentinel:
@@ -52,7 +52,7 @@ def jaxtyped(fn, *, typechecker=_sentinel):
     ...
 
 
-def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
+def jaxtyped(fn=_sentinel, *, typechecker=_sentinel, contextual=False, _keep_shape_memo=False):
     """Decorate a function with this to perform runtime type-checking of its arguments
     and return value. Decorate a dataclass to perform type-checking of its attributes.
 
@@ -252,7 +252,7 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
         typechecker = None
 
     if fn is _sentinel:
-        return ft.partial(jaxtyped, typechecker=typechecker)
+        return ft.partial(jaxtyped, typechecker=typechecker, contextual=contextual)
     elif inspect.isclass(fn):
         if dataclasses.is_dataclass(fn) and typechecker is not None:
             # This does not check that the arguments passed to `__init__` match the
@@ -260,6 +260,9 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
             # dataclass-generated `__init__` used alongside
             # `equinox.field(converter=...)`
 
+            assert not contextual
+            fn.__jaxtyped__ = True
+
             init = fn.__init__
 
             @ft.wraps(init)
@@ -280,7 +283,7 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
                 # metaclass `__call__`, because Python doesn't allow you
                 # monkey-patch metaclasses.
                 if self.__class__.__init__ is fn.__init__:
-                    _check_dataclass_annotations(self, typechecker)
+                    _check_dataclass_annotations(self, typechecker, keep_shape_memo=False)
 
             fn.__init__ = __init__
         return fn
@@ -515,13 +518,30 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
                 bound = param_signature.bind(*args, **kwargs)
                 bound.apply_defaults()
 
-                memos = push_shape_memo(bound.arguments)
+                if contextual:
+                    assert len(args) > 0
+                    self = args[0]
+                    assert getattr(self, "__jaxtyped__", False)
+                    try:
+                        _check_dataclass_annotations(self, typechecker, keep_shape_memo=True)
+                    except:
+                        # 1. dataclass not initialized through init
+                        # 2. field modified after initialization
+                        pop_shape_memo()
+                        raise
+                    memos = get_shape_memo()
+                    *_, arguments = memos
+                    arguments.clear()  # Comment to merge
+                    arguments.update(bound.arguments)
+                else:
+                    memos = push_shape_memo(bound.arguments)
                 try:
                     # Put this in a separate frame to make debugging easier, without
                     # just always ending up on the `pop_shape_memo` line below.
                     return wrapped_fn_impl(args, kwargs, bound, memos)
                 finally:
-                    pop_shape_memo()
+                    if not _keep_shape_memo:
+                        pop_shape_memo()
 
         return wrapped_fn
 
@@ -534,7 +554,7 @@ class _JaxtypingContext:
         pop_shape_memo()
 
 
-def _check_dataclass_annotations(self, typechecker):
+def _check_dataclass_annotations(self, typechecker, keep_shape_memo):
     """Creates and calls a function that checks the attributes of `self`
 
     `self` should be a dataclass instance. `typechecker` should be e.g.
@@ -578,7 +598,7 @@ def _check_dataclass_annotations(self, typechecker):
         signature,
         output=False,
     )
-    f = jaxtyped(f, typechecker=typechecker)
+    f = jaxtyped(f, typechecker=typechecker, contextual=False, _keep_shape_memo=keep_shape_memo)
     f(self, **values)

This patch allows the following snippet to work correctly:

import equinox as eqx
import jax

from jaxtyping import Array, Float, Key, jaxtyped, config
config.update("jaxtyping_remove_typechecker_stack", True)
from beartype import beartype

@jaxtyped(typechecker=beartype)
class Linear(eqx.Module):
    in_size: int = eqx.field(static=True)
    out_size: int = eqx.field(static=True)
    weight: Float[Array, "O I"]
    bias: Float[Array, "O"]

    def __init__(self, in_size: int, out_size: int, *, key: Key[Array, ""]):
        self.in_size = in_size
        self.out_size = out_size

        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    @jaxtyped(typechecker=beartype, contextual=True)
    def __call__(self, x: Float[Array, "I"]) -> Float[Array, "O"]:
        return self.weight @ x + self.bias

model = Linear(2, 3, key=jax.random.key(0))
x = jax.numpy.zeros((2,))
model(x)

Okay, I quite like this approach! This seems like a nice improvement.

I think my main goal would be to remove the two extra arguments to jaxtyped:

  • Can we somehow detect that we're decorating a dataclass method, and avoid the need for contextual=True?
  • I think we can avoid _keep_shape_memo by adjusting _check_dataclass_annotations to actually return its function-to-be-checked. Then in its current call site we can immediately wrap it in jaxtyped and call it. In the new call site we can check it as part of the current context.

I'm aware that the former is technically a (very minor) breaking change that could result in extra errors. We can bump the version number to reflect this; I'd rather do that than have an extra flag.
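Here is a minimal sketch of the second suggestion. It uses hypothetical names and plain shape tuples in place of arrays. The helper returns a checker instead of running it, so each call site decides which memo the field dimensions are bound into. It is not the real `_check_dataclass_annotations`.

```python
def bind(memo: dict, dims: str, shape: tuple) -> None:
    """Bind or verify each named dimension of `shape` in `memo`."""
    names = dims.split()
    if len(names) != len(shape):
        raise TypeError(f"{dims!r} does not match shape {shape}")
    for name, size in zip(names, shape):
        if memo.setdefault(name, size) != size:
            raise TypeError(f"dim {name!r}: expected {memo[name]}, got {size}")


def make_field_checker(annotations: dict, values: dict):
    """Return a checker instead of running the check, so the caller picks the memo."""
    def check_fields(memo: dict) -> None:
        for name, dims in annotations.items():
            bind(memo, dims, values[name])
    return check_fields


checker = make_field_checker({"weight": "O I", "bias": "O"}, {"weight": (3, 2), "bias": (3,)})

# Existing call site (after __init__): check in a fresh memo and discard it.
checker({})

# New call site (inside a method): check in the memo the method goes on using,
# so its own annotations see O=3 and I=2.
method_memo: dict = {}
checker(method_memo)
bind(method_memo, "I", (2,))
print(method_memo)  # {'O': 3, 'I': 2}
```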

> Can we somehow detect that we're decorating a dataclass method, and avoid the need for contextual=True?

Detecting this at class-definition time seems non-trivial, but it becomes much easier when the dataclass method is called, since we can inspect self. However, is_dataclass(self) alone may not be sufficient; we might also need to verify whether the dataclass was decorated with jaxtyped (e.g., by setting a __jaxtyped__ attribute).
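A call-time check on `self` is cheap, but it also illustrates the limitation mentioned here. `dataclasses.is_dataclass` returns true for dataclass types as well as their instances; the Python docs suggest adding `not isinstance(obj, type)` to detect instances. It is also true for any dataclass, so on its own it cannot tell whether the class was decorated with `jaxtyped`. `is_dataclass_instance` is a hypothetical helper.

```python
import dataclasses


def is_dataclass_instance(obj) -> bool:
    # is_dataclass() is True for dataclass types *and* their instances.
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


@dataclasses.dataclass
class Plain:  # a dataclass with no jaxtyping involvement at all
    x: int


print(dataclasses.is_dataclass(Plain))  # True
print(is_dataclass_instance(Plain))     # False
print(is_dataclass_instance(Plain(1)))  # True
```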

> I think we can avoid _keep_shape_memo by adjusting _check_dataclass_annotations to actually return its function-to-be-checked. Then in its current call site we can immediately wrap it in jaxtyped and call it. In the new call site we can check it as part of the current context.

I was aware of this approach, but while developing the feature I was unsure how to keep the shape-memo context correct across both functions, especially with respect to the arguments memo. I believe some changes to the context handling may still be necessary.

In addition, several design choices need to be made:

  1. Since you prefer to control the behavior globally via a flag, should we allow users to opt in/out of this feature locally? And should this feature apply only to jaxtyped dataclasses or to any dataclass?
  2. Currently, single_memo, variadic_memo, and pytree_memo are shared/reused, but the arguments memo is not. I chose not to share the arguments to avoid shadowing and unexpected behavior; users can always access them via self if needed.

Agreed that checking at decoration time sounds hard. Checking at runtime seems doable, however:

import functools as ft
import inspect
import types


def _is_method(fn, args, kwargs):
    if isinstance(fn, types.FunctionType):
        if len(args) > 0:
            first_arg = args[0]
            cls_dict = getattr(type(first_arg), "__dict__", {})
            is_method = fn in cls_dict.values()
        else:
            parameters = inspect.signature(fn).parameters
            if len(parameters) > 0:
                first_arg_name = next(iter(parameters))
                try:
                    first_arg = kwargs[first_arg_name]
                except KeyError:
                    # Presumably we're about to get a type error from a missing
                    # argument. Don't do anything here, let the normal function call
                    # raise the error.
                    return
                else:
                    cls_dict = getattr(type(first_arg), "__dict__", {})
                    is_method = fn in cls_dict.values()
            else:
                if len(kwargs) == 0:
                    is_method = False
                else:
                    # Likewise, presumably a type error.
                    return
    else:
        # Included this branch just in case we have `jaxtyped` wrapping other
        # objects, which may not be hashable -- which is required for the
        # `fn in cls_dict` check above.
        is_method = False
    print(f"{fn} is method: {is_method}")


def is_method(fn):
    @ft.wraps(fn)
    def wrapper(*args, **kwargs):
        _is_method(wrapper, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper


@is_method
def global_function():
    pass

class A:
    @is_method
    def method_function(self):
        pass


def wrapper():
    @is_method
    def closure_function():
        pass
    return closure_function

global_function()
A().method_function()
wrapper()()
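The check above, as written, has one limitation: `type(first_arg).__dict__` contains only attributes defined directly on that class, so a method inherited from a base class is reported as not a method. Below is a hypothetical variant that walks the MRO and compares by identity, which also sidesteps class attributes with an unusual `__eq__`. Like the original, it ignores `self` passed by keyword. It also does not look through `staticmethod`/`classmethod` wrappers.

```python
import functools as ft


def defined_on_class_of(fn, first_arg) -> bool:
    # Walk the whole MRO so inherited methods count, and compare by identity
    # so class attributes with an unusual __eq__ cannot interfere.
    return any(
        value is fn
        for klass in type(first_arg).__mro__
        for value in vars(klass).values()
    )


def report_method(fn):
    @ft.wraps(fn)
    def wrapper(*args, **kwargs):
        # Toy: only looks at a positional first argument.
        wrapper.last_result = bool(args) and defined_on_class_of(wrapper, args[0])
        return fn(*args, **kwargs)
    return wrapper


class Base:
    @report_method
    def method_function(self):
        pass


class Child(Base):
    pass


@report_method
def global_function(x):
    pass


Child().method_function()
print(Base.method_function.last_result)  # True
global_function(Child())
print(global_function.last_result)  # False
```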

No strong feelings on applying this only to jaxtyped dataclasses vs. all dataclasses. After all, I think it's rare for an unchecked dataclass to nonetheless use jaxtyping annotations. I would like to avoid setting __jaxtyped__ flags, though; that kind of thing tends to be error-prone.

Not including arguments makes sense to me.