Custom env with FrameStack wrapper causes invalid actions to be passed to `env.step`
capnspacehook opened this issue · 2 comments
capnspacehook commented
🤖 Custom Gym Environment
Describe the bug
When using `gymnasium.wrappers.frame_stack.FrameStack`
with a simple custom env, I get an exception in `step` when the chosen action is used to index the env's list of actions.
Code example
import itertools
from typing import Any, List, Tuple
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers.frame_stack import FrameStack
from sbx import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv
class MyEnv(gym.Env):
    """Minimal env with a constant 1-D observation and a discrete action space.

    Each discrete action is an index into ``self.actions``: the five single
    base actions first, followed by every unordered pair of distinct base
    actions (5 + 10 = 15 actions in total).
    """

    def __init__(self) -> None:
        self.actions, self.action_space = self.actionSpace()
        self.observation_space = Box(0, 1, shape=(1,))

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        """Look up the chosen action and return a constant transition.

        Indexing ``self.actions`` with ``action`` is the line that raises the
        reported ``TypeError`` when the agent passes a non-scalar array
        instead of a single integer index.
        """
        chosen_action = self.actions[action]  # noqa: F841 - the lookup is the repro
        return self.obs(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        """Seed the base env and return the constant initial observation."""
        super().reset(seed=seed, options=options)
        return self.obs(), {}

    def obs(self):
        """Return the constant observation ``[0.5]`` as a float32 array."""
        return np.array([0.5], dtype=np.float32)

    def render(self) -> Any | List[Any] | None:
        """Rendering is not supported; always returns ``None``."""
        return None

    def actionSpace(self):
        """Build the action lookup table and the matching ``Discrete`` space.

        Returns:
            A ``(actions, space)`` tuple. ``actions`` holds each base action as
            a one-element list, followed by every unordered pair of distinct
            base actions as a tuple. ``space`` is ``Discrete(len(actions))``.
        """
        baseActions = [0, 1, 2, 3, 4]
        # combinations() yields each unordered pair exactly once, in the same
        # order as "keep a permutation unless its reverse was already kept".
        # It avoids the O(n^2) list-membership checks of that approach.
        withoutRepeats = list(itertools.combinations(baseActions, 2))
        filteredActions = [[action] for action in baseActions] + withoutRepeats
        return filteredActions, Discrete(len(filteredActions))
if __name__ == "__main__":
    # Keep the FrameStack-wrapped env under its own name. The lambda below
    # looks its variable up lazily, so rebinding that same name to the
    # DummyVecEnv would make any later call of the factory return the vec
    # env itself instead of the wrapped env.
    stacked_env = FrameStack(MyEnv(), 4)
    vec_env = DummyVecEnv([lambda: stacked_env])
    algo = PPO("MlpPolicy", vec_env)
    # The traceback shows learn() being called from this script. The original
    # arguments are not shown in the report, so 1000 is a placeholder.
    algo.learn(total_timesteps=1000)
Traceback (most recent call last):
File "/home/user/", line 61, in <module>
File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/ppo/", line 315, in learn
return super().learn(
File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/", line 259, in learn
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/common/", line 152, in collect_rollouts
new_obs, rewards, dones, infos = env.step(clipped_actions)
File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/", line 197, in step
return self.step_wait()
File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/", line 58, in step_wait
obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
File "/home/user/jax-venv/lib/python3.10/site-packages/gymnasium/wrappers/", line 179, in step
observation, reward, terminated, truncated, info = self.env.step(action)
File "/home/user/", line 21, in step
chosenAction = self.actions[action]
TypeError: only integer scalar arrays can be converted to a scalar index
### System Info
- OS: Linux-6.5.6-76060506-generic-x86_64-with-glibc2.35 # 202310061235 169739694522.04~9283e32 SMP PREEMPT_DYNAMIC Sun O
- Python: 3.10.12
- Stable-Baselines3: 2.1.0
- PyTorch: 2.1.0+cu121
- GPU Enabled: True
- GPU Model: Nvidia RTX 3080 Ti
- Numpy: 1.26.1
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
SBX at the latest commit was installed using pip:
`pip install git+<sbx repository URL>`
### Checklist
- I have read the documentation (required)
- I have checked that there is no similar issue in the repo (required)
- I have checked my env using the env checker (required)
- I have provided a minimal working example to reproduce the bug (required)
araffin commented
Thanks for the bug report.
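I suspect the issue comes from a flatten layer that is not applied in SBX. Without it, a 2D observation (such as the (4, 1) observation that `FrameStack` produces here) leads to a non-scalar action being passed to `env.step`, and that action cannot be used as an index.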
A quick fix is to use `VecFrameStack`
instead, which stacks on the last axis rather than the first:
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
# Workaround: stack frames with VecFrameStack instead of gymnasium's
# FrameStack. VecFrameStack stacks on the last axis, so an observation of
# shape (1,) becomes (4,) and stays 1D. FrameStack instead adds a new leading
# axis, which produces a 2D observation.
# NOTE(review): `env` here is assumed to be the plain MyEnv instance (not
# already wrapped in FrameStack) - confirm against the snippet above.
vec_env = DummyVecEnv([lambda: env])
vec_env = VecFrameStack(vec_env, 4)
To reproduce with even more minimal code:
from typing import Any, List, Tuple
import gymnasium as gym
from gymnasium.spaces import Box, Discrete
from sbx import PPO
class MyEnv(gym.Env):
    """Minimal repro env: a 2-D ``(2, 1)`` float32 observation, 15 actions.

    This env uses a 2-D observation directly rather than creating one with a
    ``FrameStack`` wrapper. Every transition returns a random sample from the
    observation space, zero reward, and never terminates or truncates.
    """

    def __init__(self) -> None:
        self.observation_space = Box(0, 1, shape=(2, 1), dtype="float32")
        self.action_space = Discrete(15)

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        """Ignore ``action`` and return a random observation."""
        return self.observation_space.sample(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        """Seed the base env and return a random initial observation."""
        super().reset(seed=seed, options=options)
        return self.observation_space.sample(), {}

    def render(self) -> Any | List[Any] | None:
        """Rendering is not supported; always returns ``None``."""
        return None
if __name__ == "__main__":
    # Guard the training run so importing this module does not start it.
    PPO("MlpPolicy", MyEnv()).learn(total_timesteps=1000)