Repair unbatched PixelCNN example
juliuskunze commented
The unbatched PixelCNN example fails with

Exception: Can't lift Traced<ShapedArray(float32[32,32,1]):JaxprTrace(level=6/0)> to JaxprTrace(level=5/0)

while initializing parameters in down_shifted_conv during trace_to_jaxpr. An older version of the network worked on an unbatched input with jaxnet==0.1.4 and jax==0.1.41, so this is probably a bug in the tracing logic of jaxnet.core.
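For reference, here is a minimal plain-JAX sketch of what a down-shifted convolution computes on a single unbatched image of the shape in the error message, float32[32,32,1]. It is not jaxnet's API or the example's code: the function name, its (kernel, x) signature and the 2x3 filter size are assumptions, with the padding scheme borrowed from PixelCNN++ (pad fh - 1 rows above the image, and symmetrically left and right). Batching is then added with jax.vmap rather than by the layer itself.

```python
import jax
import jax.numpy as jnp
from jax import lax


def down_shifted_conv(kernel, x):
    """Hypothetical sketch: down-shifted convolution on one unbatched HWC image.

    kernel: (fh, fw, in_ch, out_ch) in HWIO layout; fw is assumed odd.
    x: (H, W, in_ch) with no batch axis.
    """
    fh, fw = kernel.shape[0], kernel.shape[1]
    # Pad fh - 1 rows above (none below), so each output pixel depends only on
    # input rows at or above it; pad left/right equally to keep the width.
    padding = [(fh - 1, 0), ((fw - 1) // 2, (fw - 1) // 2)]
    # conv_general_dilated expects a batch axis: add one, then drop it again.
    y = lax.conv_general_dilated(
        x[None], kernel,
        window_strides=(1, 1),
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y[0]


key = jax.random.PRNGKey(0)
kernel = 0.1 * jax.random.normal(key, (2, 3, 1, 8))  # assumed 2x3 filter, 8 outputs

image = jnp.zeros((32, 32, 1))            # one unbatched image, as in the issue
out = down_shifted_conv(kernel, image)    # shape (32, 32, 8)

# The same single-image function applied to a batch via vmap.
batch = jnp.zeros((4, 32, 32, 1))
batched_out = jax.vmap(down_shifted_conv, in_axes=(None, 0))(kernel, batch)
```

Writing the layer for one example and adding the batch axis with vmap keeps the unbatched and batched cases on the same code path, which is the usage the issue expects to work.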