MichaelTMatthews/Craftax

`play_craftax` performance

Closed this issue · 3 comments

When I ran play_craftax, I noticed the following two performance problems:

  1. The first environment step is really slow. I guess it's due to JAX tracing + compilation. I wonder whether compilation is necessary for interactive mode. If human experts play at ~5 steps per second in a single environment (the author speed cited in the Craftax paper), is it necessary to use jit?

  2. Subsequent steps are fast, but the game is very resource-intensive on my laptop (as measured by CPU usage, laptop temperature, battery drain, etc.), making it uncomfortable to play for more than a couple of minutes. Since there is not a lot of processing going on between turns, I think this should be considered a bug.

    I think the cause is as follows. Looking at the play_craftax game loop, I noticed that it polls for input and renders the state on every iteration, with no apparent rate-limiting mechanism. It therefore runs my processor as fast as it can, repeatedly re-rendering the state even though nothing is changing. Two potential solutions are (1) inserting a fraction-of-a-second time.sleep at the end of the loop, or, better yet, if available, (2) using a pygame primitive that blocks the process while waiting for input rather than polling for input in get_action_from_keypress.
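For illustration, option (2) could look roughly like the sketch below. It relies on `pygame.event.wait()`, which puts the process to sleep until an event arrives. The helper name `wait_for_keypress` is hypothetical and is not the repository's `get_action_from_keypress`; mapping the returned key code to a Craftax action is left out.

```python
import pygame


def wait_for_keypress():
    """Block until a key is pressed or the window is closed.

    Hypothetical helper for illustration only. Returns the pygame key
    code of the pressed key, or None if the user closed the window.
    Assumes pygame.init() and pygame.display.set_mode() were called.
    """
    while True:
        # Unlike polling with pygame.event.get(), wait() sleeps until
        # the next event arrives, so an idle game uses almost no CPU.
        event = pygame.event.wait()
        if event.type == pygame.QUIT:
            return None
        if event.type == pygame.KEYDOWN:
            return event.key
        # Any other event (mouse motion, window exposure, ...) is ignored.
```

The trade-off of a fully blocking wait is that nothing else in the loop runs between key presses, which is fine for a turn-based play script but would stall any animation.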

I know that Craftax is an RL research environment, not primarily a video game designed for human players. But these issues seem potentially easy to fix, and fixing them could make a big difference to the experience of researchers getting oriented with the environment (such as myself).

Regarding 1:

I found the --debug command line option for play_craftax which disables jit. The result is that each step takes a few seconds. I see now why you included jit by default.

I was not expecting such a performance drop without jit! So now I am confused. For my own curiosity, I would be interested in any insights into why the steps take so long without jit. It seems like there should not be that much processing required to evolve and render the state of a single copy of the environment?

Thanks for bringing this to my attention - I have mostly been running on a beefy desktop so hadn't realised quite how bad the performance was on these scripts.

I've pushed a fix that

  1. Only renders a new frame when a button is pressed, otherwise the old frame is simply redrawn
  2. Rate limits the play script to a configurable (default 60) FPS

I'm seeing significantly better performance from this now (CPU down from 90% to 20%)
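The two changes in this fix can be sketched in a library-independent way as below. This is not the code that was pushed: `run_play_loop` and its callback parameters are hypothetical names, and `max_frames` exists only so the sketch terminates. In a pygame script the sleep at the end would typically be replaced by `pygame.time.Clock().tick(fps)`.

```python
import time


def run_play_loop(poll_action, step_and_render, redraw, initial_frame,
                  fps=60, max_frames=None):
    """Rate-limited loop that re-renders only when an action occurs.

    All parameter names are hypothetical:
      poll_action()       -> an action, or None if no key was pressed
      step_and_render(a)  -> steps the environment and returns a new frame
      redraw(frame)       -> cheaply blits an already rendered frame
    """
    frame_budget = 1.0 / fps
    frame = initial_frame
    frames_done = 0
    while max_frames is None or frames_done < max_frames:
        start = time.perf_counter()

        action = poll_action()
        if action is not None:
            # Expensive path: only taken when the player did something.
            frame = step_and_render(action)
        # Cheap path: show the cached frame again.
        redraw(frame)

        # Sleep away whatever is left of this frame's time budget.
        elapsed = time.perf_counter() - start
        if elapsed < frame_budget:
            time.sleep(frame_budget - elapsed)
        frames_done += 1
    return frame
```

With this structure an idle loop costs one cheap redraw per frame, and the expensive step-and-render path runs at most once per key press.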

As for your question about why un-jitted code is so slow, I would say that the code was never meant to be run outside of JIT, so it is horribly optimised for that case. For instance, consider the JAX indexed update x = x.at[y].set(z), which occurs in some form all over the code (especially in the rendering). Outside of a JIT this actually makes a copy of x, sets index y to value z and returns the new copy (leaving the old x to be garbage collected). This is clearly very inefficient compared to a simple in-place update, which is what this line is compiled to inside a JIT. There are many other similar instances where the 'JAX way' of doing something doesn't make sense if the code is taken literally and not compiled.
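A minimal sketch of that pattern, assuming nothing about the actual Craftax rendering code (`paint_row` and the array size are made up for illustration):

```python
import jax
import jax.numpy as jnp


def paint_row(canvas):
    # A chain of functional updates, similar in spirit to writing
    # pixels into an image one region at a time.
    for i in range(canvas.shape[0]):
        canvas = canvas.at[i].set(i)
    return canvas


canvas = jnp.zeros(8)

# Eager (un-jitted): every .at[i].set(i) is dispatched as its own
# operation and returns a freshly allocated array.
eager = paint_row(canvas)

# Jitted: the Python loop is unrolled while tracing, and XLA compiles
# the whole chain as one program, where the updates can be applied
# in place on an intermediate buffer.
compiled = jax.jit(paint_row)(canvas)

# Both paths compute the same values, and the input is never mutated.
print(eager)
print(compiled)
print(canvas)
```

The results are identical either way; what differs is the number of dispatched operations and intermediate allocations, which is where the un-jitted slowdown described above comes from.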

I hope this answers your question and let me know if the play script runs better for you now.

Yes, that answers my question, and the play scripts run much better for me now. Thanks!