MichaelTMatthews/Craftax

Natural language state descriptions

Closed this issue · 4 comments

Hi!

Does Craftax currently provide natural language descriptions (or captions) of the states?

If so, how can these be accessed?

Thanks!

Edit: I was able to render the textual descriptions:

import jax
from craftax_classic.envs.craftax_symbolic_env import CraftaxClassicSymbolicEnv
from craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
from environment_base.wrappers import LogWrapper
from craftax.renderer import render_craftax_text

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
rngs = jax.random.split(_rng, 3)

# Create environment
env = CraftaxSymbolicEnv()
env_params = env.default_params

# Get an initial state and observation
obs, state = env.reset(rngs[0], env_params)

# Pick random action
action = env.action_space(env_params).sample(rngs[1])

# Step environment
obs, state, reward, done, info = env.step(rngs[2], state, action, env_params)

# Get the text representation of the state
text_state = render_craftax_text(state)

but I am not sure this works if using batched environments and/or the LogWrapper. Could you please confirm?

Hi Roger,

The current render_craftax_text function operates on EnvState. If you are using the LogWrapper, then you'll be passing around a LogEnvState, which wraps the EnvState. To retrieve the wrapped EnvState, use state.env_state. I don't think there should be any issues with batched environments.
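For illustration, here is a minimal sketch of that unwrapping step. EnvState, LogEnvState, render_text and unwrap are hypothetical stand-ins defined only for this example; they mimic the nesting described above, not Craftax's actual classes. With the real LogWrapper you would pass state.env_state to render_craftax_text in the same way.

```python
from typing import NamedTuple

import jax.numpy as jnp


class EnvState(NamedTuple):
    """Hypothetical stand-in for an environment's EnvState."""
    player_position: jnp.ndarray
    player_health: jnp.ndarray


class LogEnvState(NamedTuple):
    """Hypothetical stand-in for a logging wrapper's state."""
    env_state: EnvState           # the wrapped environment state
    episode_returns: jnp.ndarray  # bookkeeping added by the wrapper


def render_text(state: EnvState) -> str:
    """Hypothetical renderer that, like render_craftax_text, expects a bare EnvState."""
    x, y = (int(v) for v in state.player_position)
    return f"Player at ({x}, {y}), health {float(state.player_health)}"


def unwrap(state):
    """Follow .env_state links until reaching a state with no further wrapper."""
    while hasattr(state, "env_state"):
        state = state.env_state
    return state


inner = EnvState(player_position=jnp.array([24, 24]), player_health=jnp.array(9.0))
logged = LogEnvState(env_state=inner, episode_returns=jnp.array(0.0))

print(render_text(logged.env_state))  # Player at (24, 24), health 9.0
print(render_text(unwrap(logged)))    # same string
```

The generic unwrap helper only matters if several state-wrapping layers are stacked; with a single LogWrapper, state.env_state is enough.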

Let me know if this helps. If you're still having issues, please post the steps to reproduce and I'll take a look.

Hi! Thanks for replying :)

This code works fine:

import jax
from environment_base.wrappers import BatchEnvWrapper
from craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
from craftax.renderer import render_craftax_text

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
rngs = jax.random.split(_rng, 3)

env = CraftaxSymbolicEnv()
env_params = env.default_params

obs, state = env.reset(rngs[0], env_params)

print(
    render_craftax_text(state)
)

and this one doesn't:

import jax
from environment_base.wrappers import BatchEnvWrapper
from craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
from craftax.renderer import render_craftax_text

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
rngs = jax.random.split(_rng, 3)

env = BatchEnvWrapper(
    CraftaxSymbolicEnv(),
    num_envs=32
)
env_params = env.default_params

obs, state = env.reset(rngs[0], env_params)

print(
    render_craftax_text(state)
)

since every attribute of state now has an additional leading dimension (e.g. state.player_position.shape is now (32, 2) instead of (2,)).
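That extra leading dimension is exactly what vmapping a single-environment function produces. The sketch below assumes BatchEnvWrapper batches this way (vmapping reset over a batch of keys, the usual gymnax-style pattern); that is an assumption, and EnvState and reset here are toy stand-ins, not Craftax code.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvState(NamedTuple):  # hypothetical stand-in for Craftax's EnvState
    player_position: jnp.ndarray  # shape (2,)
    timestep: jnp.ndarray         # shape ()


def reset(key: jax.Array) -> EnvState:
    """Toy single-environment reset: random start position on a 48x48 map."""
    pos = jax.random.randint(key, (2,), 0, 48)
    return EnvState(player_position=pos, timestep=jnp.array(0))


keys = jax.random.split(jax.random.PRNGKey(0), 32)
single = reset(keys[0])
# One reset per key; every leaf of the output pytree is stacked along axis 0.
# Leaves that don't depend on the key (timestep) are broadcast to the batch size.
batched = jax.vmap(reset)(keys)

print(single.player_position.shape)   # (2,)
print(batched.player_position.shape)  # (32, 2)
print(batched.timestep.shape)         # (32,)
```

So any function written for a single unbatched EnvState will see arrays with one more axis than it expects.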

How can we get the text observation for each env in this case?

Ty!

Ah, I see the issue. If render_craftax_text were JITtable, then it would be as simple as

text_representations = jax.vmap(render_craftax_text)(state)

However, it isn't JITtable right now (although making it so would be possible with a bit of wrangling).
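To make the vmap point concrete: jax.vmap traces its function with abstract values, so code that needs concrete Python numbers, or that returns a str, fails under it. The minimal sketch below uses a hypothetical render_text to show the failure, plus one possible form of the "wrangling": vmap/jit only the numeric part and build the strings on the host afterwards. EnvState, render_text and features are illustrative names, not Craftax's API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class EnvState(NamedTuple):  # hypothetical stand-in for Craftax's EnvState
    player_position: jnp.ndarray  # per-env shape (2,)
    player_health: jnp.ndarray    # per-env shape ()


def render_text(state: EnvState) -> str:
    """Hypothetical renderer: needs concrete Python values and returns a str."""
    x, y = (int(v) for v in state.player_position)
    return f"Player at ({x}, {y}), health {float(state.player_health)}"


def features(state: EnvState) -> jnp.ndarray:
    """Traceable part: pack the numbers the caption needs into one array."""
    return jnp.concatenate([
        state.player_position.astype(jnp.float32),
        state.player_health[None].astype(jnp.float32),
    ])


batched = EnvState(
    player_position=jnp.arange(6).reshape(3, 2),
    player_health=jnp.array([9.0, 8.0, 7.0]),
)

# 1) vmap over the string renderer fails: int()/float() on a tracer raise
#    ConcretizationTypeError, which is a subclass of TypeError.
failed = False
try:
    jax.vmap(render_text)(batched)
except TypeError:
    failed = True

# 2) vmap + jit the numeric part, then format the strings on the host.
feats = np.asarray(jax.jit(jax.vmap(features))(batched))  # shape (3, 3)
texts = [f"Player at ({int(x)}, {int(y)}), health {float(hp)}" for x, y, hp in feats]
print(texts[0])  # Player at (0, 1), health 9.0
```

The design idea is simply to keep everything that touches tracers numeric and fixed-shape; string formatting, which is inherently Python-side, happens once the results are concrete arrays.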
So for now, to extract an individual environment's state, we have to use a tree map like so:

for i in range(32):
    # Index every leaf of the batched state to get environment i's EnvState
    state_i = jax.tree_util.tree_map(
        lambda x: x[i], state
    )

    print(render_craftax_text(state_i))
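A self-contained version of the loop above, using toy stand-ins (EnvState and render_text are hypothetical, not Craftax's). It uses jax.tree_util.tree_map, the long-standing name for this function; the jax.tree_map alias is deprecated in newer JAX releases. As an assumed optimisation, it also copies the whole batch to host memory once with jax.device_get, so each per-environment slice is a cheap NumPy index. If the batched state comes from the LogWrapper, apply the same loop to state.env_state.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvState(NamedTuple):  # hypothetical stand-in for Craftax's EnvState
    player_position: jnp.ndarray
    player_health: jnp.ndarray


def render_text(state: EnvState) -> str:
    """Hypothetical renderer that only accepts a single, unbatched state."""
    x, y = (int(v) for v in state.player_position)
    return f"Player at ({x}, {y}), health {float(state.player_health)}"


num_envs = 4
batched = EnvState(
    player_position=jnp.arange(2 * num_envs).reshape(num_envs, 2),
    player_health=jnp.full((num_envs,), 9.0),
)

# Copy the whole batch to host memory once (leaves become NumPy arrays),
# then take the i-th slice of every leaf to rebuild each environment's state.
host_state = jax.device_get(batched)
texts = [
    render_text(jax.tree_util.tree_map(lambda x: x[i], host_state))
    for i in range(num_envs)
]
for text in texts:
    print(text)
```

Because the lambda is called immediately inside each iteration, capturing the loop variable i is safe here.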

Let me know if this fixes the problem.

P.S. Testing this code revealed a bug that I've now fixed on main, so you'll want to pull the latest commit; otherwise, this will error.

Ah! So that is what tree_map() does! :)

This solved it, thank you!