/DeepRL

Primary LanguagePython

DeepRL

DeepRL is a pytorch implementation of Deep Reinforcement Learning algorithms. The project contains the following modules:

  • models
  • algorithms
  • evaluator

Models

Models implements a set of common policies and value functions.

Implemented policies are:

  • Gaussian
  • Gaussian Bounded
  • Gaussian Clipped
  • Deterministic
  • Cross Entropy Method (CEM)

Implemented value functions are:

  • ValueFunction
  • ActionValueFunction

Algorithms

The algorithms module implements a base class (BaseRL) for RL agents. Implemented algorithms are:

Trust Region Policy Optimization (TRPO)

Proximal Policy Optimization (PPO)

Twin Delayed Deep Deterministic Policy Gradient (TD3)

Soft Actor-Critic (SAC)

Cross Entropy Guided Policy (CGP)

An example of the exposed interface is shown in the following:

import gym
from algorithms import TD3

env = gym.make('Pendulum-v1')
agent = TD3(env)
done = True
for t in range(timesteps):
  if done:
      state = env.reset()
      state, reward, done = agent.act(state)
      metrics = agent.learn()

Evaluator

The evaluator module addresses the issue of reproducable RL, by seperating the evaluation process from the core algorithms. The evaluation class allows to easily generate statistical evaluation of an algorithm

import gym 
import dm_control2gym
from algorithms import PPO, TRPO, SAC, CGP, TD3
from evaluator import Evaluator
from evaluator.plot import plot_learning_curves, load_dataset

envs = [
    ('cartpole', 'balance'),
    ('cartpole', 'swingup'),
    ('acrobot', 'swingup'),
    ('cheetah', 'run'),
    ('hopper', 'hop'),
    ('walker', 'run')
]

algs = [TD3, CGP, SAC, TRPO, PPO]

for alg in algs:
  for domain, task in envs:
        env = dm_control2gym.make(domain_name=domain, task_name=task)
        agent = alg(env)
        evl = Evaluator(agent, 'path/to/output/data')
        evl.run_statistic(samples=20, seed=0)
returns = load_dataset('path/to/returns_offline.npz')
fig, ax = plot_learning_curves(returns, interval='bs')

Requirements

  • Python3
  • PyTorch
  • Gym
  • MuJoCo
  • dm_control
  • dm_control2gym