Denys88/rl_games

How to use CNN in PPO

nuomizai opened this issue · 3 comments

Hi, does anyone know where I can find an example config file like rl_games/configs/ppo_continuous.yaml, but one that uses a CNN to process image input? I tried the following config:

config:
    name: ${resolve_default:FrankaCabinet,${....experiment}}
    full_experiment_name: ${.name}
    env_name: rlgpu
    ppo: True
    mixed_precision: False
    normalize_input: True
    normalize_value: True
    num_actors: ${....task.env.numEnvs}
    reward_shaper:
      scale_value: 0.01
    normalize_advantage: True
    gamma: 0.99
    tau: 0.95
    learning_rate: 5e-4
    lr_schedule: adaptive
    kl_threshold: 0.008
    score_to_win: 10000
    max_epochs: ${resolve_default:1500,${....max_iterations}}
    save_best_after: 200
    save_frequency: 100
    print_stats: True
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    e_clip: 0.2
    horizon_length: 16
    minibatch_size: 5
    mini_epochs: 8
    critic_coef: 4
    clip_value: True
    seq_len: 4
    bounds_loss_coef: 0.0001
    use_entral_value: True
    central_value_config:
      normalize_input: True
      learning_rate: 0.0005
      input_shape: [3, 320, 480]
      model:
        name: continuous_a2c_logstd

      network:
        name: resnet_actor_critic
        separate: False
        value_shape: 1
        space:
          discrete:

        cnn:
          conv_depths: [ 16, 32, 32 ]
          activation: relu
          initializer:
            name: default
          regularizer:
            name: 'None'

        mlp:
          units: [ 256, 128, 64 ]
          activation: elu
          d2rl: False

          initializer:
            name: default
          regularizer:
            name: None

I set use_entral_value to True and added central_value_config, but the following error occurred:

Traceback (most recent call last):
  File "train.py", line 133, in launch_rlg_hydra
    'checkpoint': cfg.checkpoint
  File "/home/quan/rl_games/rl_games/torch_runner.py", line 109, in run
    self.run_train(args)
  File "/home/quan/rl_games/rl_games/torch_runner.py", line 88, in run_train
    agent = self.algo_factory.create(self.algo_name, base_name='run', params=self.params)
  File "/home/quan/rl_games/rl_games/common/object_factory.py", line 15, in create
    return builder(**kwargs)
  File "/home/quan/rl_games/rl_games/torch_runner.py", line 38, in <lambda>
    self.algo_factory.register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs))
  File "/home/quan/rl_games/rl_games/algos_torch/a2c_continuous.py", line 59, in __init__
    self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
  File "/home/quan/rl_games/rl_games/algos_torch/central_value.py", line 37, in __init__
    self.model = network.build(state_config)
  File "/home/quan/rl_games/rl_games/algos_torch/models.py", line 28, in build
    return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape,
  File "/home/quan/rl_games/rl_games/algos_torch/network_builder.py", line 766, in build
    net = A2CResnetBuilder.Network(self.params, **kwargs)
  File "/home/quan/rl_games/rl_games/algos_torch/network_builder.py", line 599, in __init__
    NetworkBuilder.BaseNetwork.__init__(self, **kwargs)
  File "/home/quan/rl_games/rl_games/algos_torch/network_builder.py", line 35, in __init__
    nn.Module.__init__(self, **kwargs)
TypeError: __init__() got an unexpected keyword argument 'num_agents'
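The last frame shows the failure mode: a builder keyword argument (`num_agents`) is forwarded through `**kwargs` until it reaches `nn.Module.__init__`, which does not accept it. The sketch below reproduces that pattern in plain Python. It uses a stand-in class in place of `torch.nn.Module` so it runs without PyTorch. The class names and the `pop`-based fix are illustrative assumptions, not the actual change made in rl_games.

```python
class ModuleStandIn:
    """Stand-in for torch.nn.Module: its __init__ accepts no extra keyword arguments."""

    def __init__(self):
        self.initialized = True


class BrokenNetwork(ModuleStandIn):
    """Forwards every builder kwarg to the base class, like the failing frame."""

    def __init__(self, params, **kwargs):
        # 'num_agents' (a builder-only key) reaches the base __init__ and raises TypeError.
        ModuleStandIn.__init__(self, **kwargs)
        self.params = params


class FixedNetwork(ModuleStandIn):
    """Consumes builder-only kwargs before calling the base __init__ (hypothetical fix)."""

    def __init__(self, params, **kwargs):
        # Take out the key this network understands so it is not forwarded.
        num_agents = kwargs.pop("num_agents", 1)
        ModuleStandIn.__init__(self, **kwargs)
        self.num_agents = num_agents
        self.params = params


def try_build(cls, **kwargs):
    """Return the built network, or the TypeError message if construction fails."""
    try:
        return cls({"name": "resnet_actor_critic"}, **kwargs)
    except TypeError as err:
        return str(err)


if __name__ == "__main__":
    print(try_build(BrokenNetwork, num_agents=1))
    print(try_build(FixedNetwork, num_agents=1).num_agents)
```

The broken variant returns an "unexpected keyword argument 'num_agents'" message, matching the traceback; the fixed variant builds normally.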

So, is there any config file that I can refer to?

Hey, I'll take a look later today.
There might be a bug with resnet_actor_critic.
Could you also try a regular simple CNN?
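When switching to a plain CNN, it helps to check how large the flattened feature vector becomes for a 3x320x480 image, because that vector feeds the first MLP layer. The sketch below applies the standard convolution output-size formula, floor((size + 2*padding - kernel) / stride) + 1, to an assumed three-layer stack (Atari-style kernels 8/4/3, strides 4/2/1, filters 32/64/64). These layer settings are an illustration only and are not taken from this thread.

```python
from dataclasses import dataclass


@dataclass
class ConvLayer:
    filters: int
    kernel_size: int
    strides: int
    padding: int = 0


def conv_output_shape(input_shape, convs):
    """Return (channels, height, width) after applying each conv layer in order."""
    channels, height, width = input_shape
    for layer in convs:
        height = (height + 2 * layer.padding - layer.kernel_size) // layer.strides + 1
        width = (width + 2 * layer.padding - layer.kernel_size) // layer.strides + 1
        if height <= 0 or width <= 0:
            raise ValueError(f"{layer} shrinks the image to {height}x{width}")
        channels = layer.filters
    return channels, height, width


def flattened_size(input_shape, convs):
    """Number of features handed to the MLP after flattening the conv output."""
    c, h, w = conv_output_shape(input_shape, convs)
    return c * h * w


# Assumed Atari-style stack; not from the issue's config.
SIMPLE_CNN = [
    ConvLayer(filters=32, kernel_size=8, strides=4),
    ConvLayer(filters=64, kernel_size=4, strides=2),
    ConvLayer(filters=64, kernel_size=3, strides=1),
]

if __name__ == "__main__":
    print(conv_output_shape((3, 320, 480), SIMPLE_CNN))  # (64, 36, 56)
    print(flattened_size((3, 320, 480), SIMPLE_CNN))  # 129024
```

With these assumed settings the conv stack outputs 64x36x56 = 129,024 features, so a first MLP layer of 256 units (as in the config above) would need about 33 million weights.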

Could you also try the master branch? I fixed some issues a few weeks ago.
Also, you don't need input_shape: [3, 320, 480].
The shape is taken automatically from the state space.
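The point about input_shape can be illustrated with a small helper that fills the central-value input shape from the environment's state space instead of relying on a hand-written value. This is a minimal sketch under assumptions: the function name `resolve_cv_input_shape` and the plain-dict config layout are hypothetical and do not reproduce rl_games internals. It only shows why a hard-coded [3, 320, 480] is redundant and can drift out of sync with the environment.

```python
import warnings


def resolve_cv_input_shape(cv_config, state_shape):
    """Return a copy of cv_config whose input_shape comes from the state space.

    cv_config: dict parsed from central_value_config (hypothetical layout).
    state_shape: shape of the environment's state space, e.g. (3, 320, 480).
    """
    resolved = dict(cv_config)  # do not mutate the caller's config
    state_shape = tuple(state_shape)
    user_shape = resolved.pop("input_shape", None)
    if user_shape is not None and tuple(user_shape) != state_shape:
        # A stale hand-written shape is ignored in favor of the real one.
        warnings.warn(
            f"input_shape {tuple(user_shape)} ignored; using state space shape {state_shape}"
        )
    resolved["input_shape"] = state_shape
    return resolved


if __name__ == "__main__":
    cfg = {"normalize_input": True, "learning_rate": 0.0005}
    print(resolve_cv_input_shape(cfg, (3, 320, 480))["input_shape"])
```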

Thank you, @Denys88. I tried your latest version and removed the input_shape parameter. Now it works well!