juliuskunze/jaxnet

Repair unbatched PixelCNN example


The unbatched PixelCNN example fails with

Exception: Can't lift Traced<ShapedArray(float32[32,32,1]):JaxprTrace(level=6/0)> to JaxprTrace(level=5/0)

while initializing parameters in down_shifted_conv during trace_to_jaxpr. An older version of the network worked on unbatched inputs with jaxnet==0.1.4 and jax==0.1.41, so this is likely a bug in the tracing logic of jaxnet.core.
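To make "unbatched" concrete: the network is written for a single float32[32,32,1] image (the shape in the traceback), and batching is added afterwards with jax.vmap. The sketch below is a plain-JAX stand-in, not jaxnet's implementation, and it does not reproduce the tracing error. It assumes the PixelCNN++ convention for a down-shifted convolution (pad KH - 1 rows at the top and (KW - 1) // 2 columns on each side, then a VALID convolution, with an odd KW). The kernel shape [2, 3, 1, 4] is an illustrative choice.

```python
import jax
import jax.numpy as jnp
from jax import lax


def down_shifted_conv(image, kernel):
    """Down-shifted convolution on ONE unbatched HWC image (sketch).

    image:  float32[H, W, C_in]
    kernel: float32[KH, KW, C_in, C_out], KW assumed odd.
    Output pixel (i, j) only sees input rows <= i.
    """
    kh, kw = kernel.shape[:2]
    # Pad only above the image, and symmetrically left/right.
    padded = jnp.pad(image, ((kh - 1, 0), ((kw - 1) // 2, (kw - 1) // 2), (0, 0)))
    out = lax.conv_general_dilated(
        padded[None],  # add a singleton batch dim: [1, H', W', C_in]
        kernel,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return out[0]  # drop the singleton batch dim again


# Batch the unbatched function with vmap; the kernel is shared (in_axes=None).
images = jnp.ones((8, 32, 32, 1), jnp.float32)
kernel = jnp.ones((2, 3, 1, 4), jnp.float32)
out = jax.vmap(down_shifted_conv, in_axes=(0, None))(images, kernel)
print(out.shape)  # (8, 32, 32, 4)
```

Writing the layer for a single example and lifting it with vmap keeps shapes simple. It also means parameter initialization runs under an extra level of tracing, which is where the traceback reports the mismatched trace levels.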