jax-ml/jax

TracerBoolConversionError

Closed this issue · 4 comments

Description

I am trying to compute the Mandelbrot set using JAX, and after transforming the function with vmap I get a TracerBoolConversionError.

Code:

import jax
from jax import numpy as jnp
from jax import jit,vmap
def mandle_brot(c,max_iter):
    z= jnp.array([0.0+0.0j])
    for i in range(max_iter):
        z=z*z+c
        if(jnp.abs(z)>2.0):
            return jnp.array([i])
    return jnp.array([max_iter])
z = jnp.complex_(-0.75 + 0.1j)
mandle_brot(z,100)
Z= jnp.array([z,z,z,z])
mandle_brot_vmap = vmap(mandle_brot,in_axes=(0,None))
mandle_brot_vmap(Z,10) # error from this part

Error:

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[123], line 1
----> 1 mandle_brot_vmap(Z,10)

    [... skipping hidden 3 frame]

Cell In[91], line 5, in mandle_brot(c, max_iter)
      3     for i in range(max_iter):
      4         z=z*z+c
----> 5         if(jnp.abs(z)>2.0):
      6             return jnp.array([i])
      7     return jnp.array([max_iter])

    [... skipping hidden 1 frame]

File ~/anaconda3/envs/ML/lib/python3.11/site-packages/jax/_src/core.py:1554, in concretization_function_error.<locals>.error(self, arg)
   1553   def error(self, arg):
-> 1554     raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1].
This BatchTracer with object id 140294572667376 was created on line:
  /tmp/ipykernel_16031/267948302.py:5:11 (mandle_brot)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Any help would be appreciated.

System info (python version, jaxlib version, accelerator, etc.)

I am running JAX under WSL2.

jax: 0.4.35
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0]
device info: NVIDIA GeForce RTX 4050 Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='fool', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')

$ nvidia-smi
Wed Nov  6 22:20:50 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 561.09         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4050 ...     On |   00000000:01:00.0 Off |                  N/A |
| N/A   36C    P8              1W /   75W |    4768MiB /   6141MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     16031      C   /python3.11                               N/A        |
+-----------------------------------------------------------------------------------------+

As the error indicates, the problem is the Python if statement on jnp.abs(z)>2.0. You can use jax.lax.cond instead. See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit and https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html.
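As a minimal sketch of what lax.cond looks like here, assuming the goal is a single Mandelbrot update step that leaves an already-escaped point unchanged (escape_step is a hypothetical helper, not part of the original code): the branch choice becomes part of the traced computation instead of a Python if. This only replaces one decision; it does not reproduce the early return from the loop.

```python
import jax
import jax.numpy as jnp
from jax import lax


def escape_step(z, c):
    """One Mandelbrot update that leaves z unchanged once |z| > 2 (hypothetical helper)."""
    # lax.cond takes a traced boolean scalar and selects a branch; both
    # branches receive the same operand (z) and must return the same type.
    return lax.cond(
        jnp.abs(z) > 2.0,
        lambda z: z,          # already escaped: keep z as-is
        lambda z: z * z + c,  # not escaped yet: apply z -> z^2 + c
        z,
    )


zs = jnp.array([0.0 + 0.0j, 3.0 + 0.0j], dtype=jnp.complex64)
cs = jnp.array([1.0 + 0.0j, 1.0 + 0.0j], dtype=jnp.complex64)
out = jax.vmap(escape_step)(zs, cs)
print(out)
```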

The problem is with this line:

        if(jnp.abs(z)>2.0):
            return jnp.array([i])

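jnp.abs(z) is a traced value, so its value cannot be known while the Python if statement is executing. Under vmap it also stands for a whole batch of values, one per element of Z, so there is no single True or False for the if to act on.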
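A small illustration, assuming a scalar toy function rather than the full loop (clip_escaped and clip_escaped_where are hypothetical names): under vmap the Python if sees a tracer standing in for the whole batch, so it raises jax.errors.TracerBoolConversionError, while jnp.where keeps the comparison inside the traced computation and selects elementwise.

```python
import jax
import jax.numpy as jnp


def clip_escaped(z):
    # A Python `if` needs a concrete bool; under vmap, z is a tracer.
    if jnp.abs(z) > 2.0:
        return jnp.float32(2.0)
    return jnp.abs(z)


def clip_escaped_where(z):
    # jnp.where computes both sides and selects elementwise, so no bool() is needed.
    return jnp.where(jnp.abs(z) > 2.0, 2.0, jnp.abs(z))


zs = jnp.array([0.5 + 0.0j, 3.0 + 4.0j], dtype=jnp.complex64)

try:
    jax.vmap(clip_escaped)(zs)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

result = jax.vmap(clip_escaped_where)(zs)
print(raised, result)
```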

For background on this, I'd start with JAX Key Concepts, and then refer to JAX Sharp Bits: Control Flow. The solution here will likely be to use jax.lax.while_loop rather than a Python for loop with an early return.
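A sketch of the lax.while_loop approach, assuming the intent is to return the loop index at which |z| first exceeds 2, or max_iter if it never does, as in the original code (mandelbrot_iters is a hypothetical name). The carry holds z, the counter, and an escaped flag, and the loop condition takes the place of the early return. Here z starts as a scalar; the original jnp.array([0.0+0.0j]) has shape (1,), which is why the error reports bool[1].

```python
import jax
import jax.numpy as jnp
from jax import lax


def mandelbrot_iters(c, max_iter):
    """Loop index at which |z| first exceeds 2, or max_iter if it never does."""

    def cond_fun(carry):
        _, i, escaped = carry
        # Continue while the point has not escaped and iterations remain.
        return jnp.logical_and(jnp.logical_not(escaped), i < max_iter)

    def body_fun(carry):
        z, i, _ = carry
        z = z * z + c
        escaped = jnp.abs(z) > 2.0
        # Only advance the counter if this step did not escape, so on exit
        # i equals the index the original Python loop would have returned.
        i = jnp.where(escaped, i, i + 1)
        return z, i, escaped

    init = (jnp.zeros_like(c), jnp.array(0, dtype=jnp.int32), jnp.array(False))
    _, i, _ = lax.while_loop(cond_fun, body_fun, init)
    return i


cs = jnp.array([-0.75 + 0.1j, 0.0 + 0.0j, 1.5 + 0.0j, 3.0 + 0.0j], dtype=jnp.complex64)
# max_iter is not mapped (in_axes=None), so it stays a Python int inside the function.
counts = jax.vmap(mandelbrot_iters, in_axes=(0, None))(cs, 100)
print(counts)
```

Because max_iter is passed with in_axes=None, it is not traced, so it can be compared directly in cond_fun; only c carries the batch dimension.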

Feel free to ask back here if you have any questions!

> You can use jax.lax.cond instead.

This is true in many cases, but in this particular code cond will not work, because it cannot be used to break out of a Python for-loop.


True, I was just thinking of something along these lines:

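import jax.numpy as jnp
from jax import lax

def mandle_brot(c, max_iter):
    iters = jnp.arange(1, max_iter + 1)
    def body(carry, i):
        z, count = carry
        z = z * z + c
        escaped = jnp.abs(z) > 2.0
        # On the first escape, count becomes i; later steps keep it, since i only grows.
        count = lax.cond(escaped, lambda _: jnp.minimum(count, i), lambda _: count, None)
        return (z, count), count
    # Carry is (z, escape step); max_iter means "has not escaped".
    init = (0.0 + 0.0j, max_iter)
    _, counts = lax.scan(body, init, iters)
    return jnp.min(counts)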
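A related sketch that swaps the scan/cond pair for lax.fori_loop with jnp.where masking (a substitution, not what was proposed in the thread; mandelbrot_iters_masked is a hypothetical name). Under vmap, a lax.cond with a batched predicate is executed as a select over both branches anyway, so jnp.where states that directly. It also freezes z once the point escapes, so z is not repeatedly squared toward inf/nan. Every point still runs all max_iter steps.

```python
import jax
import jax.numpy as jnp
from jax import lax


def mandelbrot_iters_masked(c, max_iter):
    """Escape index via a fixed-length loop with masking instead of an early exit."""

    def body(i, carry):
        z, count = carry
        # count stays at max_iter until the point escapes, so this marks live points.
        active = count == max_iter
        z_new = z * z + c
        # Record the loop index of the first escape, mirroring `return i`.
        count = jnp.where(active & (jnp.abs(z_new) > 2.0), i.astype(count.dtype), count)
        # Freeze z once the point has escaped so it is not squared toward inf/nan.
        z = jnp.where(active, z_new, z)
        return z, count

    init = (jnp.zeros_like(c), jnp.array(max_iter, dtype=jnp.int32))
    _, count = lax.fori_loop(0, max_iter, body, init)
    return count


cs = jnp.array([-0.75 + 0.1j, 0.0 + 0.0j, 1.5 + 0.0j, 3.0 + 0.0j], dtype=jnp.complex64)
counts = jax.vmap(mandelbrot_iters_masked, in_axes=(0, None))(cs, 100)
print(counts)
```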

A while loop does seem more natural, though. (Then again, if the function is being vmapped, does a while loop save anything? Under vmap, the elements that finish early would just do redundant computation after that point, which is then selected out.)
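On that question: as far as I understand JAX's batching rule for lax.while_loop, when the predicate depends on a vmapped value, the batched loop keeps running until the predicate is False for every element, and elements that finished earlier have their carry held fixed by a select. So a vmapped while_loop saves work only when the whole batch finishes before max_iter; it is not a per-element early exit. The toy loop below (steps_to_reach is a hypothetical name) shows that each element still gets its own result despite the shared trip count.

```python
import jax
import jax.numpy as jnp
from jax import lax


def steps_to_reach(target):
    """Count the +1 increments needed to reach `target` (toy loop)."""

    def cond_fun(carry):
        x, _ = carry
        return x < target

    def body_fun(carry):
        x, n = carry
        return x + 1, n + 1

    _, n = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.int32(0)))
    return n


targets = jnp.array([1, 5, 100], dtype=jnp.int32)
# The predicate is batched here: the loop runs until every element is done,
# and finished elements keep their carry unchanged, so each count is correct.
steps = jax.vmap(steps_to_reach)(targets)
print(steps)
```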