How to use CNN in PPO
nuomizai opened this issue · 3 comments
Hi, does anyone know where I can find an example config file like rl_games/configs/ppo_continuous.yaml
except that I want to use CNN to handle the image input? I tried to set the config file as follows:
config:
  name: ${resolve_default:FrankaCabinet,${....experiment}}
  full_experiment_name: ${.name}
  env_name: rlgpu
  ppo: True
  mixed_precision: False
  normalize_input: True
  normalize_value: True
  num_actors: ${....task.env.numEnvs}
  reward_shaper:
    scale_value: 0.01
  normalize_advantage: True
  gamma: 0.99
  tau: 0.95
  learning_rate: 5e-4
  lr_schedule: adaptive
  kl_threshold: 0.008
  score_to_win: 10000
  max_epochs: ${resolve_default:1500,${....max_iterations}}
  save_best_after: 200
  save_frequency: 100
  print_stats: True
  grad_norm: 1.0
  entropy_coef: 0.0
  truncate_grads: True
  e_clip: 0.2
  horizon_length: 16
  minibatch_size: 5
  mini_epochs: 8
  critic_coef: 4
  clip_value: True
  seq_len: 4
  bounds_loss_coef: 0.0001
  use_central_value: True
  central_value_config:
    normalize_input: True
    learning_rate: 0.0005
    input_shape: [3, 320, 480]

model:
  name: continuous_a2c_logstd

network:
  name: resnet_actor_critic
  separate: False
  value_shape: 1
  space:
    discrete:
  cnn:
    conv_depths: [16, 32, 32]
    activation: relu
    initializer:
      name: default
    regularizer:
      name: 'None'
  mlp:
    units: [256, 128, 64]
    activation: elu
    d2rl: False
    initializer:
      name: default
    regularizer:
      name: None
I set use_central_value to True and set central_value_config. But an error occurred:
Traceback (most recent call last):
File "train.py", line 133, in launch_rlg_hydra
'checkpoint': cfg.checkpoint
File "/home/quan/rl_games/rl_games/torch_runner.py", line 109, in run
self.run_train(args)
File "/home/quan/rl_games/rl_games/torch_runner.py", line 88, in run_train
agent = self.algo_factory.create(self.algo_name, base_name='run', params=self.params)
File "/home/quan/rl_games/rl_games/common/object_factory.py", line 15, in create
return builder(**kwargs)
File "/home/quan/rl_games/rl_games/torch_runner.py", line 38, in <lambda>
self.algo_factory.register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs))
File "/home/quan/rl_games/rl_games/algos_torch/a2c_continuous.py", line 59, in __init__
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
File "/home/quan/rl_games/rl_games/algos_torch/central_value.py", line 37, in __init__
self.model = network.build(state_config)
File "/home/quan/rl_games/rl_games/algos_torch/models.py", line 28, in build
return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape,
File "/home/quan/rl_games/rl_games/algos_torch/network_builder.py", line 766, in build
net = A2CResnetBuilder.Network(self.params, **kwargs)
File "/home/quan/rl_games/rl_games/algos_torch/network_builder.py", line 599, in __init__
NetworkBuilder.BaseNetwork.__init__(self, **kwargs)
File "/home/quan/rl_games/rl_games/algos_torch/network_builder.py", line 35, in __init__
nn.Module.__init__(self, **kwargs)
TypeError: __init__() got an unexpected keyword argument 'num_agents'
So, is there any config file that I can refer to?
Hey, I'll take a look later today.
There might be a bug with resnet_actor_critic.
Could you try a regular simple CNN too?
Could you try the master branch too? I've fixed some issues a few weeks ago.
Also, you don't need to set input_shape: [3, 320, 480].
It automatically takes the shape from the state space.
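For reference, the "regular simple CNN" path uses the `actor_critic` network builder with an explicit `convs` list instead of `resnet_actor_critic` with `conv_depths`. A sketch of the network section, modeled on the Atari example configs shipped with rl_games (exact keys and values here are illustrative, not copied from your setup):

```yaml
network:
  name: actor_critic
  separate: False
  space:
    continuous:            # continuous_a2c_logstd expects a continuous space
      mu_activation: None
      sigma_activation: None
      mu_init:
        name: default
      sigma_init:
        name: const_initializer
        val: 0
      fixed_sigma: True
  cnn:
    type: conv2d
    activation: relu
    initializer:
      name: default
    convs:                 # one entry per conv layer
      - filters: 32
        kernel_size: 8
        strides: 4
        padding: 0
      - filters: 64
        kernel_size: 4
        strides: 2
        padding: 0
      - filters: 64
        kernel_size: 3
        strides: 1
        padding: 0
  mlp:
    units: [256, 128, 64]
    activation: elu
    initializer:
      name: default
```

Note that your pasted config also has an empty `discrete:` under `space`, which does not match the `continuous_a2c_logstd` model; a continuous space block like the one above is what that model expects.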
Thank you, @Denys88. I've tried your latest version and removed the input_shape parameter. Now it works well!