opendilab/LightZero

env problem

Closed issue · 1 comment

In both the env config and the policy config, I set observation_shape to (3, 64, 64), as in the example code. However, the image returned by the env manager has shape (1, 64, 64). How can I get 3 channels?

env=dict(
    stop_value=int(1e6),
    env_name=env_name,
    obs_shape=(3, 64, 64),
    collector_env_num=collector_env_num,
    evaluator_env_num=evaluator_env_num,
    n_evaluator_episode=evaluator_env_num,
    manager=dict(shared_memory=False, ),
),
policy=dict(
    model=dict(
        observation_shape=(3, 64, 64),
        frame_stack_num=3,
        action_space_size=action_space_size,
        downsample=True,
        self_supervised_learning_loss=True,  # default is False
        discrete_action_encoding_type='one_hot',
        norm_type='BN',
  • Hello. Whether the environment returns grayscale observations is controlled by the gray_scale parameter. To get RGB observations instead, set gray_scale=False in the env config, then update observation_shape (in both the env and model configs) and image_channel (in the model config) to match. For example:
atari_muzero_config = dict(
    exp_name=f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
    env=dict(
        ...
        gray_scale=False,
        observation_shape=(12, 96, 96),
        ...
    ),
    policy=dict(
        model=dict(
            observation_shape=(12, 96, 96),
            frame_stack_num=4,
            image_channel=3,
            action_space_size=action_space_size,
            ...
        ),
        ...
  • After changing gray_scale to False, make sure observation_shape and image_channel reflect the three color channels (RGB). Here, observation_shape=(12, 96, 96) means 4 stacked frames of 96x96 pixels with 3 channels each, so the first dimension is 4 × 3 = 12.
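
The channel arithmetic in the reply can be checked with a small sketch. It assumes, as the reply describes, that stacked frames are concatenated along the first (channel) axis. The helper name stacked_observation_shape is hypothetical and is not part of LightZero.

```python
def stacked_observation_shape(frame_stack_num, image_channel, height, width):
    """Hypothetical helper (not a LightZero API).

    Returns the channel-first shape obtained when `frame_stack_num` frames,
    each with `image_channel` channels, are concatenated along the channel axis.
    """
    if frame_stack_num < 1 or image_channel < 1:
        raise ValueError("frame_stack_num and image_channel must be >= 1")
    return (frame_stack_num * image_channel, height, width)


# RGB case from the reply: gray_scale=False, image_channel=3, frame_stack_num=4
rgb_shape = stacked_observation_shape(4, 3, 96, 96)

# Grayscale case for comparison: one channel per frame, same stack depth
gray_shape = stacked_observation_shape(4, 1, 96, 96)

print(rgb_shape)   # (12, 96, 96)
print(gray_shape)  # (4, 96, 96)
```

Under the same assumption, a config with frame_stack_num=3 and 64x64 RGB frames would need a first dimension of 3 × 3 = 9 rather than 3.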