Suggestion: alias `Transformed(WithState).apply` to `__call__`
hylkedonker opened this issue · 2 comments
Context:
I often find myself abusing decorators on Haiku functions like this:
import haiku as hk

@hk.without_apply_rng
@hk.transform
def foo(x):
    lin = hk.Linear(output_size=2)
    return lin(x)
This converts `foo` into an `hk.Transformed` container. Later, I use `foo` by calling `foo.apply(params, x)`.
Suggestion:
Link `Transformed.apply` to `Transformed.__call__` (and likewise for `TransformedWithState`).
Pros:
- The `hk.transform`-decorated function both looks and "behaves" like a function: `foo(params, x)`.
- If initialisation is not needed, additional decorators such as `jax.vmap` can be stacked on top.
Cons:
- The signature of the resulting callable is not transparent. That is, it differs from the original function and depends on which decorators were applied.
Does that sound like something that could be useful? I would love to hear your thoughts!
Thanks,
Hylke
Hey @hylkedonker, the con you've pointed to is the main reason we have not gone for this historically, namely that it is confusing to have a decorator that changes the signature of the function.
Haiku (like Optax, Stax and other libraries) produces a pair of functions whose signatures differ from that of the input function. You can actually treat this object as a pair if you would like:
def f(x):
    ...

f_init, f_apply = hk.transform(f)
And if you prefer to keep referring to the name `f`, you can actually override it in the unpacking:
def f(x):
    ...

f_init, f = hk.transform(f)
params = f_init(...)
out = f(params, ...)
Again, we don't suggest this in our guides, since it is quite confusing for new users that the signature of `f` appears to have changed from the function definition.
If you really prefer stacking decorators, then you can monkey patch Haiku in place to support this functionality:
import haiku as hk
# Make calling a Transformed(WithState) object forward to its `apply` function.
hk.Transformed.__call__ = property(fget=lambda self: self.apply)
hk.TransformedWithState.__call__ = property(fget=lambda self: self.apply)
Here is a worked example: https://colab.research.google.com/gist/tomhennigan/d9957a2233604e116d8f4ba4d9c9e3bb/add-__call__-to-transformed-withstate.ipynb
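Why a `property` works for `__call__`: Python looks up special methods such as `__call__` on the type rather than the instance, but it still honours descriptors found there. The property's getter therefore runs first and returns `self.apply`, which is then called with the arguments. A self-contained sketch of the same patch applied to a hypothetical stand-in class `TransformedLike` (no Haiku required):

```python
from typing import Any, Callable, NamedTuple


class TransformedLike(NamedTuple):
    """Hypothetical stand-in for hk.Transformed (no __call__ of its own)."""

    init: Callable[..., Any]
    apply: Callable[..., Any]


# Same patch as in the reply, applied to the stand-in. Calling an instance
# first evaluates the property (returning `self.apply`), then calls the result.
TransformedLike.__call__ = property(fget=lambda self: self.apply)

model = TransformedLike(
    init=lambda x: {"offset": 5},
    apply=lambda params, x: x + params["offset"],
)

params = model.init(0)
assert callable(model)
assert model(params, 2) == model.apply(params, 2) == 7
```

Because the patch is applied to the class, it affects every instance created anywhere in the program, which is one more reason to treat it as an opt-in convenience rather than a default.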
Thanks for clarifying!