google/tree-math

Transform for defining dataclasses with VectorMixin like flax.struct

shoyer opened this issue · 2 comments

It would be nice to have an easy way to define dataclasses that are also tree-math vectors.

We could borrow the syntax of flax.struct here: https://flax.readthedocs.io/en/latest/flax.struct.html

Example usage:

```python
from jax import Array  # assumed: the JAX array type used in the annotations
from tree_math import struct

@struct
class FluidState:
  velocity_x: Array
  velocity_y: Array
  pressure: Array
```

CC @jamieas, who started work on this.

shoyer commented

This is complete. See "Custom vector classes" in the README for usage instructions.