google-deepmind/chex

Consider supporting static attributes in chex.dataclass

NeilGirdhar opened this issue · 1 comment

from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex

def f(carry, _):
  """Scan step: advance the carry by 1.0 and emit no per-step output."""
  next_carry = carry + 1.0
  return (next_carry, None)

@jit
def do_scan(c):
  """Run `f` through `scan` starting from `c.x` and return the final carry.

  No input sequence is supplied (`xs` is None), so `c.y` is passed as the
  scan's `length` argument.
  """
  last_carry, _ys = scan(f, c.x, xs=None, length=c.y)
  return last_carry

@dataclass
class C:
  """tjax dataclass used for the variant of the example that is reported to work."""
  x: RealNumeric  # read by do_scan as the scan's initial carry
  # Declared static through tjax's field(static=True). Presumably this keeps y
  # a concrete Python value under jit so that scan can use it as its length --
  # confirm against the tjax documentation.
  y: IntegralNumeric = field(static=True)

# Reported to work. f adds 1.0 per step, so starting from 1.0 with a length of
# 10 the printed final carry is expected to be 11.0.
print(do_scan(C(1.0, 10)))  # works

@chex.dataclass
class D:
  """chex.dataclass counterpart of C, declaring the same two fields."""
  x: RealNumeric  # read by do_scan as the scan's initial carry
  # No static marker is given here (the issue asks chex to support one), so y
  # is presumably treated like any other field when D is passed through jit --
  # TODO confirm against the chex.dataclass documentation.
  y: IntegralNumeric

# Reported to fail. Presumably y is traced by jit here, and scan cannot use a
# traced value as its length; this is the motivation for the feature request.
print(do_scan(D(x=1.0, y=10)))  # fails

I guess this is a duplicate of #64!