env problem
spring520 commented
In the env config as well as the policy config, I set the observation shape to (3, 64, 64), as in the example code. However, the image returned from the env manager is (1, 64, 64). How can I set it to 3 channels?
```python
env=dict(
    stop_value=int(1e6),
    env_name=env_name,
    obs_shape=(3, 64, 64),
    collector_env_num=collector_env_num,
    evaluator_env_num=evaluator_env_num,
    n_evaluator_episode=evaluator_env_num,
    manager=dict(shared_memory=False, ),
),
policy=dict(
    model=dict(
        observation_shape=(3, 64, 64),
        frame_stack_num=3,
        action_space_size=action_space_size,
        downsample=True,
        self_supervised_learning_loss=True,  # default is False
        discrete_action_encoding_type='one_hot',
        norm_type='BN',
        # ... (rest of the config not included in the post)
```
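To see how a single channel can appear even though three were expected, here is a minimal sketch. It assumes the environment converts each RGB frame to grayscale before returning it in channel-first layout. The function `to_chw` is hypothetical and is not LightZero's actual preprocessing code, and it uses a plain channel average instead of a weighted luminance formula for simplicity.

```python
import numpy as np


def to_chw(frame: np.ndarray, gray_scale: bool) -> np.ndarray:
    """Hypothetical preprocessing: HWC RGB frame -> CHW array."""
    if gray_scale:
        # Collapse the three color channels into one, keeping a channel axis.
        frame = frame.mean(axis=2, keepdims=True)
    # Move channels first: (H, W, C) -> (C, H, W).
    return np.transpose(frame, (2, 0, 1))


rgb_frame = np.zeros((64, 64, 3), dtype=np.uint8)
print(to_chw(rgb_frame, gray_scale=True).shape)   # (1, 64, 64)
print(to_chw(rgb_frame, gray_scale=False).shape)  # (3, 64, 64)
```

In this sketch, the configured shape alone does not change what the preprocessing emits; only the grayscale switch decides between 1 and 3 channels per frame.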
puyuan1996 commented
- Hello, whether the environment outputs grayscale observations is controlled by the `gray_scale` parameter. To obtain RGB observations instead of grayscale, set `gray_scale=False` and update `observation_shape` and `image_channel` in the model config accordingly. For example, you can modify the configuration as follows:
```python
atari_muzero_config = dict(
    exp_name=f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
    env=dict(
        ...
        gray_scale=False,  # return RGB frames instead of grayscale
        observation_shape=(12, 96, 96),
        ...
    ),
    policy=dict(
        model=dict(
            observation_shape=(12, 96, 96),  # 4 stacked frames x 3 channels
            frame_stack_num=4,
            image_channel=3,
            action_space_size=action_space_size,
            ...
        ),
        ...
    ),
)
```
- After changing `gray_scale` to False, make sure you also update `observation_shape` and `image_channel` to reflect the three color channels (typically RGB). Here, `observation_shape=(12, 96, 96)` means 4 stacked frames, each 96x96 pixels with 3 channels (`image_channel=3`), giving 4 × 3 = 12 channels in total.
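The channel arithmetic in the answer can be checked with a small helper. `stacked_observation_shape` is a hypothetical function written for illustration, not part of LightZero. It assumes stacked frames are concatenated along the channel axis, which is what the 4 × 3 = 12 reasoning above implies.

```python
def stacked_observation_shape(frame_stack_num, image_channel, height, width):
    """Hypothetical helper: channel-first shape of stacked frames."""
    # Frames are assumed to be concatenated along the channel axis.
    return (frame_stack_num * image_channel, height, width)


# RGB settings from the config above: 4 frames x 3 channels = 12.
print(stacked_observation_shape(4, 3, 96, 96))  # (12, 96, 96)
# One channel per frame (grayscale) with 4 frames.
print(stacked_observation_shape(4, 1, 96, 96))  # (4, 96, 96)
# The question's settings (3 stacked 64x64 frames) with RGB frames.
print(stacked_observation_shape(3, 3, 64, 64))  # (9, 64, 64)
```

If the same convention applies to the question's setup, switching to RGB with `frame_stack_num=3` would call for `observation_shape=(9, 64, 64)` and `image_channel=3` rather than (3, 64, 64).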