jax-ml/jax

AOT compilation and serialization

Closed this issue · 13 comments

Description

Hi, I have a numerical solver that takes a long time to compile. I want to compile it ahead of time, export, serialize, and store it; then later I want to load, rehydrate, and run it without recompilation. Is this possible? To my knowledge the jax.export functionality still requires a compile after rehydration. Also, when I try this with more complex versions of the solver I start getting segfaults. Is a compilation cache the way to go here? I've been having trouble with that as well.

Currently I have a solver object that calls self.compile() in its __init__. Here is a simplified version of that compile function:

def compile(self) -> None:
    """Precompile the solve function."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    compiled_dir = os.path.join(base_dir, "../compiled")
    os.makedirs(compiled_dir, exist_ok=True)

    file_name = "filename.pkl"
    relative_path = os.path.join(compiled_dir, file_name)

    if os.path.exists(relative_path):
        with open(relative_path, "rb") as f:
            serialized = pickle.load(f)
            self._solve_compiled = export.deserialize(serialized).call
        print("Found existing compiled ilqr solver, loaded.")
        return

    print("Existing compiled solver not found, compiling...")

    placeholder_1 = jax.ShapeDtypeStruct((self.nx,), jnp.dtype("float32"))
    # ...etc

    placeholders = (placeholder_1, ...)

    start = time.time()
    jitted_solve = jax.jit(self._solve_internal)
    compiled_solve = jitted_solve.lower(*placeholders).compile()
    self._solve_compiled = compiled_solve
    print(f"compiled in {time.time() - start}s!")

    start = time.time()
    exported = export.export(jitted_solve)(*placeholders)
    serialized = exported.serialize()
    with open(relative_path, "wb") as f:
        pickle.dump(serialized, f)
    print(f"Exported and stored in {time.time() - start}s!")

The idea behind this is that if a compiled version of this class exists, use it; otherwise compile one and store it for future instantiations of the solver with the same input/output shapes.

Not sure if #476 should've been closed if you still have to recompile on load.

System info (python version, jaxlib version, accelerator, etc.)

Python Version: 3.10.15
Operating System: Linux 5.15.167.4-microsoft-standard-WSL2 (posix)
JAX Version: 0.4.35
jaxlib Version: 0.4.35
Default Backend: cpu

Was not aware that existed, thanks! I altered the code to this:

from jax.experimental.serialize_executable import (
    serialize as serialize_compiled,
    deserialize_and_load,
)

def compile(self) -> None:
    """Precompile the solve function."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    compiled_dir = os.path.join(base_dir, "../compiled")
    os.makedirs(compiled_dir, exist_ok=True)

    file_name = "filename.pkl"
    relative_path = os.path.join(compiled_dir, file_name)

    if os.path.exists(relative_path):
        with open(relative_path, "rb") as f:
            # serialize() returns (payload, in_tree, out_tree), matching this unpack
            serialized, in_tree, out_tree = pickle.load(f)
            self._solve_compiled = deserialize_and_load(serialized, in_tree, out_tree)
        print("Found existing compiled ilqr solver, loaded.")
        return

    print("Existing compiled solver not found, compiling...")

    placeholder_1 = jax.ShapeDtypeStruct((self.nx,), jnp.dtype("float32"))
    # ...etc

    placeholders = (placeholder_1, ...)

    start = time.time()
    jitted_solve = jax.jit(self._solve_internal)
    compiled_solve = jitted_solve.lower(*placeholders).compile()
    self._solve_compiled = compiled_solve
    print(f"compiled in {time.time() - start}s!")

    # start = time.time()
    # exported = export.export(jitted_solve)(*placeholders)
    # serialized = exported.serialize()
    # print(f"Exported and stored in {time.time() - start}s!")

    serialized = serialize_compiled(compiled_solve)
    with open(relative_path, "wb") as f:
        pickle.dump(serialized, f)

but I'm getting this error on deserialize_and_load:

File "/opt/project/cosyco/.venv/lib/python3.10/site-packages/jax/experimental/serialize_executable.py", line 53, in deserialize_and_load
no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load()
File "/opt/project/cosyco/.venv/lib/python3.10/site-packages/jax/experimental/serialize_executable.py", line 87, in persistent_load
return self.backend.deserialize_executable(pid[1])
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to find compiled symbol for kernel bitcast_pad_fusion.6

I feel like I'm using the API wrong, any ideas?

I don't think you're using the API wrong. This looks more like a bug in XLA deserialization. Maybe there is some part of your computation that XLA is not serializing and deserializing properly? One way to double-check is to see if you get the same failure with the compilation cache enabled. The compilation cache requires retracing, so it might not be what you want, but if you see the same problem there, then it is a problem with how XLA is serializing executables.
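For reference, enabling the persistent compilation cache for such a test might look like this (a minimal sketch; the cache directory and toy function are placeholders, not taken from the thread):

```python
import jax
import jax.numpy as jnp
from jax.experimental.compilation_cache import compilation_cache as cc

# Point the persistent cache at a directory; compiled executables are
# written here and reused across processes.
cc.set_cache_dir("/tmp/jax_cache")

# Cache everything, even small/fast compilations (handy when testing).
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return jnp.sin(x) + 1.0

y = f(jnp.arange(3.0))  # first call compiles and populates the cache
```

A second process running the same function would then hit the on-disk cache instead of recompiling (though it still retraces).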

Oops, it was a mistake on my end. The script I was running it from had these lines:
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
cc.set_cache_dir("/tmp/jax_cache")

but when I commented them out it worked! Thanks so much. I've been banging my head against the wall all day.

Follow-on: there still seems to be a bit more overhead left, though, as when I run solve multiple times the first call is noticeably slower. Compilation takes 5 seconds, though, so I know it's not recompiling (at least not from scratch).

Found existing compiled solver, loaded.
Solver created!
After 11 iterations, Linesearch succeeded with mu=0.0
Jax took 0.4983346462249756 and 0.04530315101146698s/iteration
After 11 iterations, Linesearch succeeded with mu=0.0
Jax took 0.024268627166748047 and 0.00220623891800642s/iteration
After 11 iterations, Linesearch succeeded with mu=0.0
Jax took 0.023716211318969727 and 0.0021560192108154297s/iteration
After 11 iterations, Linesearch succeeded with mu=0.0
Jax took 0.02395915985107422 and 0.0021781055256724358s/iteration
After 11 iterations, Linesearch succeeded with mu=0.0
Jax took 0.022493600845336914 and 0.002044872846454382s/iteration
After 11 iterations, Linesearch succeeded with mu=0.0
Jax took 0.024571895599365234 and 0.0022338086273521185s/iteration

Also, when I use a more complex version of the solver I still get segfaults.

A segfault would probably need some sort of reproducer in order to fix. This codepath doesn't call anything special over just calling jit directly, though, so the problem would have to be somewhere in the serialization inside XLA.

As I work on finding a minimal reproducer for the segfault, I was wondering whether the previous problem can be solved or is just inherent to JAX. Jitting, lowering, and precompiling eliminates the compile time, to my knowledge, but as shown in the previous comment there is still some overhead on the first run of the solver. When I run the solver again with the same inputs, there is a 20x speedup. Is this normal, or am I doing something wrong?

Not clear if you switched to using the compilation cache or got the AOT code to work, but the speedup probably comes from the fact that the executable has to actually be loaded onto the devices the first time it is dispatched. (Also, on GPU there is probably some lazy CUDA initialization.) These costs are fundamental to either approach (compilation cache or AOT).

Thanks for the quick response. Yes, to clarify, I got the AOT code to work and this is CPU.

I don't think I understand "you have to actually load the executable onto the devices when dispatching the first time". Is there a way to do this AOT without knowing the specific inputs? I.e., sacrifice the first run with placeholders so that the following runs with actual inputs enjoy the 20x speedup? This is a realtime robotics application, so the speedup is fairly important.

If you just want to hot-start this, you can pass dummy values (using the avals of the inputs) and then call block_until_ready on the results. Use jax.stages.Compiled.args_info to get the necessary input shapes.
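A warm-up along those lines might look like this (a minimal sketch; `solve`, the shapes, and the dummy inputs are illustrative stand-ins, not the solver from the thread):

```python
import jax
import jax.numpy as jnp
import numpy as np

def solve(x, u):
    # Stand-in for the real solver body.
    return jnp.tanh(x) @ u

# Lower and compile against placeholder shapes, as in the thread.
x_spec = jax.ShapeDtypeStruct((4, 4), jnp.float32)
u_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
compiled = jax.jit(solve).lower(x_spec, u_spec).compile()

# Warm up once with dummy values of the right shape/dtype so the first
# real call doesn't pay the executable-loading cost. (If the shapes
# aren't at hand, jax.stages.Compiled.args_info records them.)
dummies = tuple(np.zeros(s.shape, s.dtype) for s in (x_spec, u_spec))
out = compiled(*dummies)
out.block_until_ready()  # make sure the warm-up run actually finished
```

After this, subsequent calls with real inputs of the same shapes should skip the first-dispatch overhead.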

I see, so either way I have to sacrifice some time on dummy variables after loading the compiled function, even after lowering with placeholders. Fair enough. I'll do some testing and see what I can do. Thanks for your help.

One more thing: the block_until_ready is just for accurate timing, correct?
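For context (my addition, not a reply from the thread): block_until_ready matters for both timing and warm-up, because JAX dispatches work asynchronously, so a timer stopped right after the call may only measure dispatch, not the computation. A minimal illustration with a toy function:

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: (x @ x).sum())
x = jnp.ones((500, 500))
f(x).block_until_ready()  # warm up: compile + first dispatch

t0 = time.time()
y = f(x)                   # returns quickly: dispatch is asynchronous
dispatch_s = time.time() - t0
y.block_until_ready()      # wait for the computation itself to finish
total_s = time.time() - t0
# total_s includes the device work; dispatch_s may not
```

The same logic applies to the warm-up run: without block_until_ready, the "first real call" could still overlap with the in-flight dummy computation.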

The segfault was a result of jax.numpy.linalg.inv. I found issue #11321, used the inv3x3 workaround, and all of a sudden no more segfaults. Strange.
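For reference, a closed-form 3x3 inverse along the lines of that workaround might look like this (my reconstruction via the adjugate, not the exact helper from #11321):

```python
import jax.numpy as jnp

def inv3x3(m):
    """Closed-form inverse of a 3x3 matrix, avoiding jnp.linalg.inv."""
    a, b, c = m[0, 0], m[0, 1], m[0, 2]
    d, e, f = m[1, 0], m[1, 1], m[1, 2]
    g, h, i = m[2, 0], m[2, 1], m[2, 2]
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    # Adjugate (transpose of the cofactor matrix) divided by the determinant.
    adj = jnp.array([
        [e * i - f * h, c * h - b * i, b * f - c * e],
        [f * g - d * i, a * i - c * g, c * d - a * f],
        [d * h - e * g, b * g - a * h, a * e - b * d],
    ])
    return adj / det
```

Being plain elementwise arithmetic, this jits and vmaps cleanly, which may be why it sidesteps whatever codepath in the LU-based jnp.linalg.inv was crashing.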