[Minor bug] error check in hk.transform_with_state
albertfgu opened this issue · 1 comment
albertfgu commented
On this line: https://github.com/deepmind/dm-haiku/blob/6f2769e8c8dd35b3fc0e66905c877debea7d525f/haiku/_src/transform.py#L441 the check should be `if state is None` instead of `if state`. The current version triggers when an empty state is passed in (i.e. `state={}`), which produces a confusing and incorrect error message that quotes the wrong method signature:
ValueError: Apply must be called with an RNG as the second argument, the required signature is: `apply(params, rng, *a, **k)`
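To illustrate why the truthiness check misfires, here is a minimal pure-Python sketch. It is not Haiku's source: the function names and message strings are hypothetical, chosen only to mirror the two signatures discussed above. An empty dict is falsy, so `if state` treats `state={}` the same as "no state at all", while `state is None` only matches a genuinely missing state.

```python
from typing import Any, Mapping, Optional

# Hypothetical messages paraphrasing the two apply signatures; not Haiku's exact text.
NO_STATE_MSG = "RNG expected as the second argument: apply(params, rng, *a, **k)"
WITH_STATE_MSG = "RNG expected as the third argument: apply(params, state, rng, *a, **k)"


def pick_message_truthy(state: Optional[Mapping[str, Any]]) -> str:
    # Buggy pattern: an empty mapping is falsy, so state={} selects the no-state message.
    return WITH_STATE_MSG if state else NO_STATE_MSG


def pick_message_none_check(state: Optional[Mapping[str, Any]]) -> str:
    # Fixed pattern: only a missing state (None) selects the no-state message.
    return NO_STATE_MSG if state is None else WITH_STATE_MSG


if __name__ == "__main__":
    empty_state: dict = {}
    print(pick_message_truthy(empty_state))      # selects NO_STATE_MSG (the wrong signature)
    print(pick_message_none_check(empty_state))  # selects WITH_STATE_MSG
```

The only difference between the two helpers is `state` versus `state is None`; for any non-empty mapping they agree, which is why the bug only surfaces for stateless functions transformed with `hk.transform_with_state`.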
tomhennigan commented
Thanks for the report @albertfgu, I'll get a fix in for this. Here is a minimal reproducer:
>>> import haiku as hk
>>> f = hk.transform_with_state(lambda: None)
>>> f.apply({}, {}, 'a')
ValueError: Apply must be called with an RNG as the second argument, the required signature is: `apply(params, rng, *a, **k)`. The object was of type <class 'str'>: a