chex.Dimensions API enhancement
I would like to propose an API enhancement that allows chex.Dimensions to be used inside function annotations. If there is interest, I'd like to contribute. Example below:
dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
def foo(arr: chex.Array):
    chex.assert_shape(arr, dims['BTE'])
    # fn logic
### turns into ###
def foo(arr: chex.Array(dims['BTE'])):  # behind the scenes assert on function call
    # fn logic
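For reference, something close to the proposed behaviour can be sketched today with typing.Annotated and a decorator. This is only an illustration, not chex API: the Shape marker and the check_shapes decorator are hypothetical names, and SimpleNamespace stands in for an array so the snippet has no dependencies (anything with a .shape attribute would work). Since dims['BTE'] evaluates to a tuple of ints, Shape(*dims['BTE']) could take the place of the literal sizes used here.

```python
import functools
import inspect
from types import SimpleNamespace
from typing import Annotated, Any, get_args, get_origin, get_type_hints


class Shape:
    """Hypothetical annotation marker holding an expected shape tuple."""

    def __init__(self, *dims: int):
        self.dims = tuple(dims)


def check_shapes(fn):
    """Hypothetical decorator: assert annotated shapes on every call."""
    # include_extras=True keeps the Annotated metadata instead of stripping it.
    hints = get_type_hints(fn, include_extras=True)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # argument has no shape annotation
            for meta in get_args(hint)[1:]:
                if isinstance(meta, Shape):
                    actual = tuple(value.shape)
                    assert actual == meta.dims, (
                        f"{name}: expected shape {meta.dims}, got {actual}")
        return fn(*args, **kwargs)

    return wrapper


# Stand-in for dims['BTE'] with B=2, T=3, E=4.
BTE = Shape(2, 3, 4)


@check_shapes
def foo(arr: Annotated[Any, BTE]):
    return arr.shape


foo(SimpleNamespace(shape=(2, 3, 4)))  # passes the shape check
```

Calling foo with an object whose shape is not (2, 3, 4) raises an AssertionError before the function body runs, which is the "behind the scenes" assert described above.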
The annotation form is particularly useful for dataclasses, e.g.:
dims = chex.Dimensions(B=batch_size, T=rollout_len)
# asserts are run on instantiation
class TimeStep:
    q_values: chex.Array(dims['BT'])
    discounts: chex.Array(dims['BT'])
    rewards: chex.Array(dims['BT'])
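A rough approximation of the dataclass case is possible with the standard library alone, by reading the expected shapes from Annotated metadata in __post_init__. The shape_checked decorator below is a hypothetical name, not part of chex, and SimpleNamespace again stands in for an array. The tuple BT plays the role of dims['BT'].

```python
import dataclasses
from types import SimpleNamespace
from typing import Annotated, Any, get_args, get_origin, get_type_hints


def shape_checked(cls):
    """Hypothetical decorator: make cls a dataclass that asserts shapes on init."""

    def __post_init__(self):
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            if get_origin(hint) is not Annotated:
                continue  # field has no shape annotation
            expected = get_args(hint)[1]  # first metadata item
            if not isinstance(expected, tuple):
                continue
            actual = tuple(getattr(self, name).shape)
            assert actual == expected, (
                f"{name}: expected shape {expected}, got {actual}")

    # Attach the hook before dataclass() so the generated __init__ calls it.
    cls.__post_init__ = __post_init__
    return dataclasses.dataclass(cls)


# Stand-in for dims['BT'] with B=2, T=5.
BT = (2, 5)


@shape_checked
class TimeStep:
    q_values: Annotated[Any, BT]
    discounts: Annotated[Any, BT]
    rewards: Annotated[Any, BT]


ok = SimpleNamespace(shape=BT)
step = TimeStep(q_values=ok, discounts=ok, rewards=ok)  # asserts run here
```

The expected shape stays visible in the field annotation, so an editor can show it on hover, while the check itself runs only once per instantiation.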
Pros:
- reduces clutter that asserts can add
- lets the user see the shape a function or class expects directly in the editor (e.g. the VS Code hover popup)
  - example: with RLax, to find out what shape a loss fn expects for each arg, you currently have to either read the source code or wait for a call to trip an assert
Cons:
- increased API complexity
- ...?
Thanks for your interest in chex!
This suggestion is very interesting. Many of us who work with arrays in Python on a daily basis have been eagerly awaiting PEP 646, which was accepted for Python 3.11.
Once Python 3.11 becomes more mainstream we will definitely consider incorporating shape annotations into chex. Perhaps we could also augment or fork chex.Dimensions to return TypeVarTuples, along the lines of your suggestion.
For the time being, however, we will not implement such a change. In particular, mixing runtime checks with static type annotation is out of scope, at least for now.
P.S. If you're interested in validating type annotations at runtime, you might find the pydantic project useful: https://docs.pydantic.dev/