/DQN_CartPole_tf

强化学习 CartPole环境,Tensorflow实现DQN。

Primary LanguagePythonMIT LicenseMIT

代码 参考修改自: PARL实现DQN,CartPole环境

参考视频: 世界冠军带你从零实践强化学习

介绍

强化学习 CartPole环境,Tensorflow实现DQN。

运行train.py文件即可。 支持gym的两种环境CartPole-v0MountainCar-v0

if __name__ == "__main__":
    main(env="CartPole-v0")
    # main(env="MountainCar-v0")

依赖

gym == 0.21.0
numpy == 1.22.2
tensorflow == 2.4.0

DQN的两大创新点

  1. 经验回放(Experience Repaly)
  2. 固定Q目标(Fixed Q Target)

经验回放(Experience Repaly)

每个时间步agent与环境交互得到的转移样本存储在buffer中。

当进行模型参数的更新时,从buffer中随机抽取batch_size个数据,构造损失函数,利用梯度下降更新参数。

通过这种方式,

  1. 去除数据之间的相关性,缓和了数据分布的差异。
  2. 提高了样本利用率,进而提高了模型学习效率。

为什么要去除数据之间的相关性?

参考:

关于强化学习中经验回放(experience replay)的两个问题?

为什么机器学习中, 要假设我们的数据是独立同分布的?

理解1: 确保数据是独立同分布的。这样,我们搭建的模型是简单、易用的。

理解2: 在一般环境中,智能体得到奖励的情况往往很少。比如在n个格子上,只有1个格子有奖励。智能体在不断地尝试中,大多数情况是没有奖励的。如果没有Experience Repaly,只是每步都进行更新,模型可能会找不到“正确的道路”,陷入局部最优不收敛情况。


固定Q目标(Fixed Q Target)

参考: DQL: Dueling Double DQN, Prioritized Experience Replay, and fixed Q-targets(三下) 【前面一点的内容】

在DQN中,损失函数的定义是,使Q尽可能地逼近Q_target

在实际情况中,Q在变化,作为 “标签”Q_target也在不断地变化。它使得我们的算法更新不稳定,即输出的Q在不断变化,训练的损失曲线轨迹是震荡的。


DQN引入了target_net。具体来说,使用value_net输出Q值,使用target_net输出Q_target值。

target_netvalue_net具有相同的网络结构,但不共享参数。

  1. 在一段时间内,target_net保持不变,只训练value_net。这样,相当于固定住“标签”Q_target,然后使用预测值Q不断逼近。
  2. 一段时间过后,将value_net的权重 复制到 target_net上,完成target_net参数的更新。

通过这种方式,一定程度降低了当前Q值和target_Q值的相关性,提高了算法稳定性。