Denys88/rl_games

rl_games with Brax trains so fast that step_time = 0, causing a crash

Closed this issue · 4 comments

Running

python runner.py --train --file rl_games/configs/brax/ppo_ant.yaml

trains so fast that step_time is measured as 0.0, which then causes a ZeroDivisionError in two different places:
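A minimal sketch of the failure mode follows. It assumes that step_time is the difference of two clock readings and that the clock only advances in coarse ticks. The thread does not show which clock rl_games uses, and the quantize helper and the 1/64 s tick are hypothetical, chosen only to make the effect deterministic. Two instants a few milliseconds apart land in the same tick, so the interval is exactly 0.0 and the unguarded curr_frames / step_time division raises.

```python
import math
import time

TICK = 1 / 64  # assumed ~15.6 ms clock granularity, for illustration only


def quantize(t: float, tick: float) -> float:
    """Round a timestamp down to a whole number of ticks (hypothetical coarse clock)."""
    return math.floor(t / tick) * tick


def fps_step(curr_frames: int, start: float, end: float) -> float:
    """Unguarded rate computation, like curr_frames / step_time in the traceback."""
    step_time = end - start
    return curr_frames / step_time


# Two real instants 3 ms apart fall inside the same tick.
start = quantize(10.001, TICK)
end = quantize(10.004, TICK)
print(end - start)  # 0.0

try:
    fps_step(32768, start, end)
except ZeroDivisionError as exc:
    print("crash:", exc)

# Reported granularity of Python's clocks on the current machine.
for name in ("time", "perf_counter"):
    print(name, time.get_clock_info(name).resolution)
```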

fps step: 1048568.0 fps step and policy inference: 699047.1  fps total: 299589.9 epoch: 3/1000
fps step: 1048576.0 fps step and policy inference: 699047.1  fps total: 299589.9 epoch: 4/1000
fps step: 699050.7 fps step and policy inference: 419429.1  fps total: 262141.5 epoch: 5/1000
fps step: 524284.0 fps step and policy inference: 524284.0  fps total: 253261.5 epoch: 6/1000
fps step: 613857.2 fps step and policy inference: 613857.2  fps total: 282772.3 epoch: 7/1000
fps step: 699043.6 fps step and policy inference: 524280.0  fps total: 259049.0 epoch: 8/1000
fps step: 699040.0 fps step and policy inference: 419433.0  fps total: 233017.7 epoch: 9/1000
fps step: 699054.2 fps step and policy inference: 524286.0  fps total: 262141.5 epoch: 10/1000
fps step: 699054.2 fps step and policy inference: 524282.0  fps total: 262140.5 epoch: 11/1000
fps step: 699047.1 fps step and policy inference: 524284.0  fps total: 262141.5 epoch: 12/1000
fps step: 699043.6 fps step and policy inference: 349519.1  fps total: 209712.6 epoch: 13/1000
fps step: 524282.0 fps step and policy inference: 349520.9  fps total: 209712.3 epoch: 14/1000
fps step: 699040.0 fps step and policy inference: 349520.9  fps total: 234601.1 epoch: 15/1000
Traceback (most recent call last):
  File "runner.py", line 67, in <module>
    runner.run(args)
  File "F:\dev\rl_games\rl_games\torch_runner.py", line 122, in run
    self.run_train(args)
  File "F:\dev\rl_games\rl_games\torch_runner.py", line 103, in run_train
    agent.train()
  File "F:\dev\rl_games\rl_games\common\a2c_common.py", line 1158, in train
    self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
  File "F:\dev\rl_games\rl_games\common\a2c_common.py", line 284, in write_stats
    self.writer.add_scalar('performance/step_fps', curr_frames / step_time, frame)
ZeroDivisionError: float division by zero
fps step: 744015.2 fps step and policy inference: 488626.7  fps total: 301957.5 epoch: 741/1000
fps step: 699032.9 fps step and policy inference: 523978.2  fps total: 286478.0 epoch: 742/1000
fps step: 795304.5 fps step and policy inference: 432675.5  fps total: 328279.1 epoch: 743/1000
fps step: 2097216.0 fps step and policy inference: 524284.0  fps total: 299591.2 epoch: 744/1000
fps step: 524276.0 fps step and policy inference: 524276.0  fps total: 299587.9 epoch: 745/1000
fps step: 524286.0 fps step and policy inference: 524286.0  fps total: 299591.2 epoch: 746/1000
Traceback (most recent call last):
  File "runner.py", line 67, in <module>
    runner.run(args)
  File "F:\dev\rl_games\rl_games\torch_runner.py", line 122, in run
    self.run_train(args)
  File "F:\dev\rl_games\rl_games\torch_runner.py", line 103, in run_train
    agent.train()
  File "F:\dev\rl_games\rl_games\common\a2c_common.py", line 1153, in train
    fps_step = curr_frames / step_time
ZeroDivisionError: float division by zero

oh, oh. :)

Btw, you have pretty strange jumps, up to 2M fps.
On my computer it is much more stable:

fps step: 1246295.3 fps step and policy inference: 819567.2 fps total: 541351.9 epoch: 997/1000
fps step: 1241129.5 fps step and policy inference: 825741.9 fps total: 541277.2 epoch: 998/1000
fps step: 1226498.4 fps step and policy inference: 817077.4 fps total: 540681.0 epoch: 999/1000
fps step: 1250775.4 fps step and policy inference: 824889.4 fps total: 542131.3 epoch: 1000/1000
saving next best rewards:  [7628.9634]
=> saving checkpoint 'runs/Ant_brax/nn/Ant_brax.pth'
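One possible explanation for the jumps is offered here as a hypothesis rather than something confirmed in the thread. Suppose step_time is measured with a clock that advances in ticks of 1/64 s (about 15.6 ms). In that case the reported rates can only take a handful of discrete values. The 32768 frames per epoch used below is a hypothetical figure that is not stated in the thread. With these assumptions, intervals of 1 to 6 ticks give rates close to several values in the first log: about 2.1M, 1.05M, 699k, 524k, 419k, and 349k.

```python
TICK = 1 / 64        # assumed clock granularity (~15.6 ms); a hypothesis, not measured
CURR_FRAMES = 32768  # hypothetical frames collected per epoch; not stated in the thread


def quantized_fps(ticks: int) -> float:
    """Rate reported when the measured interval spans a whole number of clock ticks."""
    return CURR_FRAMES / (ticks * TICK)


for ticks in range(1, 7):
    print(ticks, round(quantized_fps(ticks), 1))
# 1 2097152.0
# 2 1048576.0
# 3 699050.7
# 4 524288.0
# 5 419430.4
# 6 349525.3
```

Under this reading, a step shorter than one tick can be measured as zero ticks, which is exactly the case that crashes.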

I'll apply a fix a little later.
I'll add the line step_time = max(step_time, 0.00001)
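A minimal sketch of the proposed clamp follows, written as standalone helpers. The names clamp_step_time and fps_step are hypothetical. In rl_games the line would go in train() before step_time is first used. Going by the tracebacks, step_time is divided at line 1153 and then passed to write_stats at line 1158, so clamping it once before line 1153 covers both crash sites.

```python
MIN_STEP_TIME = 0.00001  # floor from the proposed fix: 10 microseconds


def clamp_step_time(step_time: float, min_time: float = MIN_STEP_TIME) -> float:
    """Never let a measured interval be exactly zero (or negative)."""
    return max(step_time, min_time)


def fps_step(curr_frames: int, step_time: float) -> float:
    # Safe: the interval is clamped before the division.
    return curr_frames / clamp_step_time(step_time)


print(fps_step(32768, 0.015625))  # 2097152.0, a normal interval is unchanged
print(fps_step(32768, 0.0))       # large but finite instead of ZeroDivisionError
```

The trade-off is that a zero-length interval is logged as a very large but finite rate. The fps value for that epoch is then meaningless, but training no longer stops. A higher-resolution timer such as Python's time.perf_counter() would make zero intervals much less likely in the first place. The thread does not say which clock rl_games uses, though.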

@ViktorM did you fix it? :)