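Consider supporting static attributes in chex.dataclass
NeilGirdhar opened this issue · 1 comment
NeilGirdhar commented:
tjax's `dataclass` lets a field be marked static (`field(static=True)`), so under `jit` the scan length `y` stays a concrete Python int. `chex.dataclass` has no equivalent option, so the same program fails when the argument is a `chex.dataclass`: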
from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex
def f(carry, _):
    """Scan step: add one to the carry and emit no per-step output.

    Args:
        carry: running value threaded through ``scan``.
        _: per-step input slice; ignored.

    Returns:
        ``(new_carry, None)``, the ``(carry, output)`` pair ``scan`` expects.
    """
    incremented = carry + 1.0
    step_output = None
    return incremented, step_output
@jit
def do_scan(c):
    """Apply ``f`` ``c.y`` times starting from ``c.x`` and return the final carry.

    ``c.y`` is forwarded as ``scan``'s ``length`` (there are no ``xs`` to scan
    over). ``scan`` needs ``length`` as a concrete integer, so tracing under
    ``jit`` only succeeds when ``c.y`` is not itself traced, e.g. when it is a
    static dataclass field.
    """
    final_carry, _per_step_outputs = scan(f, c.x, xs=None, length=c.y)
    return final_carry
@dataclass
class C:
    """tjax dataclass whose scan length ``y`` is declared static.

    ``y`` is marked ``static=True``, so (per tjax's static-field semantics —
    confirm against tjax docs) it is kept out of the traced leaves and stays a
    concrete Python int under ``jit``.
    """
    x: RealNumeric  # initial carry; not marked static, so presumably traced under jit
    y: IntegralNumeric = field(static=True)  # scan length; static so it is not traced
print(do_scan(C(x=1.0, y=10)))  # works: static y reaches scan as a concrete length
@chex.dataclass
class D:
    """chex dataclass with the same fields as the tjax example.

    There is no way to mark ``y`` static here (that missing feature is what this
    issue requests), so ``y`` is presumably treated as an ordinary traced leaf
    under ``jit`` — confirm against chex docs.
    """
    x: RealNumeric  # initial carry
    y: IntegralNumeric  # intended scan length; cannot be declared static
print(do_scan(D(y=10, x=1.0)))  # fails: y is traced, so scan gets no concrete length (presumably a concretization error)
NeilGirdhar commented
I guess this is a duplicate of #64!