Custom env with FrameStack wrapper causes invalid actions to be passed to `env.step`
capnspacehook opened this issue · 2 comments
capnspacehook commented
🤖 Custom Gym Environment
### Describe the bug
When using `gymnasium.wrappers.frame_stack.FrameStack` with a simple custom env, I get an exception when the chosen action is used in `step`.
### Code example
```python
import itertools
from typing import Any, List, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers.frame_stack import FrameStack
from sbx import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv


class MyEnv(gym.Env):
    def __init__(self) -> None:
        self.actions, self.action_space = self.actionSpace()
        self.observation_space = Box(0, 1, shape=(1,))
        super().__init__()

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        chosenAction = self.actions[action]
        return self.obs(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        super().reset(seed=seed, options=options)
        return self.obs(), {}

    def obs(self):
        return np.array([0.5], dtype=np.float32)

    def render(self) -> Any | List[Any] | None:
        pass

    def actionSpace(self):
        baseActions = [0, 1, 2, 3, 4]
        totalActionsWithRepeats = list(itertools.permutations(baseActions, 2))
        withoutRepeats = []
        for combination in totalActionsWithRepeats:
            reversedCombination = combination[::-1]
            if reversedCombination not in withoutRepeats:
                withoutRepeats.append(combination)
        filteredActions = [[action] for action in baseActions] + withoutRepeats
        return filteredActions, Discrete(len(filteredActions))


if __name__ == "__main__":
    env = MyEnv()
    check_env(env)
    env = FrameStack(env, 4)
    env = DummyVecEnv([lambda: env])

    algo = PPO("MlpPolicy", env)
    algo.learn(total_timesteps=1000)
```
```
Traceback (most recent call last):
  File "/home/user/sbx_ppo_repro.py", line 61, in <module>
    algo.learn(total_timesteps=1000)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/ppo/ppo.py", line 315, in learn
    return super().learn(
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 259, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/common/on_policy_algorithm.py", line 152, in collect_rollouts
    new_obs, rewards, dones, infos = env.step(clipped_actions)
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 197, in step
    return self.step_wait()
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 58, in step_wait
    obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
  File "/home/user/jax-venv/lib/python3.10/site-packages/gymnasium/wrappers/frame_stack.py", line 179, in step
    observation, reward, terminated, truncated, info = self.env.step(action)
  File "/home/user/sbx_ppo_repro.py", line 21, in step
    chosenAction = self.actions[action]
TypeError: only integer scalar arrays can be converted to a scalar index
```
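For context, NumPy raises this exact `TypeError` when a plain Python list is indexed with an array that is not a scalar, so the action reaching `step` is presumably an array rather than a single integer. A minimal illustration:

```python
import numpy as np

actions = list(range(5))
actions[np.int64(2)]        # OK: a scalar integer index works on a plain list
actions[np.array([2, 3])]   # TypeError: only integer scalar arrays can be converted to a scalar index
```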
### System Info
- OS: Linux-6.5.6-76060506-generic-x86_64-with-glibc2.35 # 202310061235~1697396945~22.04~9283e32 SMP PREEMPT_DYNAMIC Sun O
- Python: 3.10.12
- Stable-Baselines3: 2.1.0
- PyTorch: 2.1.0+cu121
- GPU Enabled: True
- GPU Model: Nvidia RTX 3080ti
- Numpy: 1.26.1
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1

`sbx` at the latest commit was installed using pip: `pip install git+https://github.com/araffin/sbx`
### Checklist
- I have read the documentation (required)
- I have checked that there is no similar issue in the repo (required)
- I have checked my env using the env checker (required)
- I have provided a minimal working example to reproduce the bug (required)
araffin commented
Hello,
thanks for the bug report.
I guess the issue comes from a flatten layer which is not applied in SBX.
A quick fix is to use a `VecFrameStack` instead (it stacks on the last axis instead of the first):

```python
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

vec_env = DummyVecEnv([lambda: env])
vec_env = VecFrameStack(vec_env, 4)
```
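Applied to the script from the report, the quick fix would look roughly like this (a sketch, assuming the `MyEnv` and PPO setup defined above):

```python
from sbx import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

vec_env = DummyVecEnv([lambda: MyEnv()])     # MyEnv from the report above
vec_env = VecFrameStack(vec_env, n_stack=4)  # stacks observations along the last axis

algo = PPO("MlpPolicy", vec_env)
algo.learn(total_timesteps=1000)
```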
To reproduce with an even more minimal example:
```python
from typing import Any, List, Tuple

import gymnasium as gym
from gymnasium.spaces import Box, Discrete
from sbx import PPO


class MyEnv(gym.Env):
    def __init__(self) -> None:
        self.observation_space = Box(0, 1, shape=(2, 1), dtype="float32")
        self.action_space = Discrete(15)

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        return self.observation_space.sample(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        super().reset(seed=seed, options=options)
        return self.observation_space.sample(), {}

    def render(self) -> Any | List[Any] | None:
        pass


PPO("MlpPolicy", MyEnv()).learn(total_timesteps=1000)
```