OfflineRL-Kit

An elegant PyTorch offline reinforcement learning library for researchers.

Primary language: Python. License: MIT.




OfflineRL-Kit is an offline reinforcement learning library built on pure PyTorch. It offers several features that make it convenient for researchers:

  • Elegant framework: the code structure is clear and easy to use
  • State-of-the-art offline RL algorithms, both model-free and model-based
  • High extensibility: new algorithms can be built from the library's components with a few lines of code
  • Support for parallel hyperparameter tuning
  • A clear and powerful logging system that makes experiments easy to manage

Supported algorithms

Benchmark Results (D4RL normalized scores, 4 seeds) (Ongoing)

Task                          CQL        TD3+BC     EDAC       IQL        MOPO       RAMBO      COMBO
halfcheetah-medium-v2         49.4±0.2   48.2±0.5   66.4±1.1   47.4±0.5   72.4±4.2   78.7±1.1   71.9±8.5
hopper-medium-v2              59.1±4.1   60.8±3.4   101.8±0.2  65.7±8.1   62.8±38.1  82.1±38.0  84.7±9.3
walker2d-medium-v2            83.6±0.5   84.4±2.1   93.3±0.8   81.1±2.6   84.1±3.2   86.1±1.0   83.9±2.0
halfcheetah-medium-replay-v2  47.0±0.3   45.0±0.5   62.3±1.4   44.2±0.6   72.1±3.8   68.5±3.6   66.5±6.5
hopper-medium-replay-v2       98.6±1.5   67.3±13.2  101.5±0.1  94.8±6.7   92.7±20.7  93.4±11.4  90.1±25.2
walker2d-medium-replay-v2     71.3±17.9  83.4±7.0   86.2±1.2   77.3±11.0  85.9±5.3   73.7±6.5   89.4±6.4
halfcheetah-medium-expert-v2  93.0±2.2   90.7±2.7   101.8±8.4  88.0±2.8   83.6±12.5  98.8±4.3   98.2±0.2
hopper-medium-expert-v2       111.4±0.5  91.4±11.3  110.5±0.3  106.2±5.6  74.6±44.2  85.0±30.7  108.8±2.6
walker2d-medium-expert-v2     109.8±0.5  110.2±0.3  113.6±0.3  108.3±2.6  108.2±4.3  78.4±45.4  110.0±0.2

Detailed logs are available at https://drive.google.com/drive/folders/11QHHDlLmUEc097tPgYvb4gZ2IaqSpkHp?usp=share_link.
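The scores above appear to be on the D4RL normalized scale, where 0 corresponds to a random policy and 100 to an expert policy on the same task. A minimal sketch of how one table entry could be computed is below. The reference returns and per-seed returns are made-up numbers, and reading "±" as the population standard deviation across seeds is an assumption that this README does not state. D4RL environments also expose `get_normalized_score`, which returns the same ratio without the factor of 100.

```python
import statistics
from typing import Sequence, Tuple


def d4rl_normalized_score(raw_return: float, random_return: float, expert_return: float) -> float:
    """Map a raw episode return onto the D4RL scale: 0 = random policy, 100 = expert policy."""
    return 100.0 * (raw_return - random_return) / (expert_return - random_return)


def mean_and_std(per_seed_scores: Sequence[float]) -> Tuple[float, float]:
    """Mean and population standard deviation (numpy.std's default) across seeds."""
    scores = list(per_seed_scores)
    return statistics.fmean(scores), statistics.pstdev(scores)


if __name__ == "__main__":
    # Hypothetical reference returns and per-seed raw returns, for illustration only.
    random_return, expert_return = 0.0, 1000.0
    raw_returns = [480.0, 500.0, 510.0, 490.0]  # one averaged evaluation return per seed
    scores = [d4rl_normalized_score(r, random_return, expert_return) for r in raw_returns]
    mean, std = mean_and_std(scores)
    print(f"{mean:.1f}±{std:.1f}")  # prints 49.5±1.1
```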

Installation

git clone https://github.com/yihaosun1124/OfflineRL-Kit.git
cd OfflineRL-Kit
python setup.py install

Quick Start

Train

The following example trains CQL. The complete script is at run_example/run_cql.py.

First, create the environment and load its offline dataset into a replay buffer:

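import gym
import d4rl  # importing d4rl registers the D4RL tasks with gym
import numpy as np

from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.load_dataset import qlearning_dataset

# args holds the options parsed in run_example/run_cql.py; obs_shape and action_dim come from the env
env = gym.make(args.task)
dataset = qlearning_dataset(env)
args.obs_shape = env.observation_space.shape
args.action_dim = np.prod(env.action_space.shape)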
buffer = ReplayBuffer(
    buffer_size=len(dataset["observations"]),
    obs_shape=args.obs_shape,
    obs_dtype=np.float32,
    action_dim=args.action_dim,
    action_dtype=np.float32,
    device=args.device
)
buffer.load_dataset(dataset)
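In offline RL the replay buffer only stores the fixed dataset and serves random mini-batches, since training never collects new transitions. The toy class below is a hypothetical stand-in, not the library's `ReplayBuffer`: it copies the five fields a `qlearning_dataset` dict is assumed to contain and samples aligned transitions uniformly with replacement.

```python
import random
from typing import Dict, List, Sequence

# Fields assumed to be present in a qlearning_dataset-style dict.
KEYS = ("observations", "actions", "next_observations", "rewards", "terminals")


class ToyReplayBuffer:
    """Hypothetical fixed-size transition store; NOT OfflineRL-Kit's ReplayBuffer."""

    def __init__(self, buffer_size: int, seed: int = 0) -> None:
        self.buffer_size = buffer_size
        self.data: Dict[str, List] = {k: [] for k in KEYS}
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self.data["rewards"])

    def load_dataset(self, dataset: Dict[str, Sequence]) -> None:
        # Copy every transition field, keeping at most buffer_size transitions.
        for k in KEYS:
            self.data[k] = list(dataset[k])[: self.buffer_size]

    def sample(self, batch_size: int) -> Dict[str, List]:
        # One shared index list keeps obs/action/reward/next_obs of a transition aligned.
        idx = [self.rng.randrange(len(self)) for _ in range(batch_size)]
        return {k: [self.data[k][i] for i in idx] for k in KEYS}


if __name__ == "__main__":
    n = 5
    toy = {
        "observations": [[float(i)] for i in range(n)],
        "actions": [[0.1 * i] for i in range(n)],
        "next_observations": [[float(i + 1)] for i in range(n)],
        "rewards": [1.0] * n,
        "terminals": [0.0] * (n - 1) + [1.0],
    }
    buf = ToyReplayBuffer(buffer_size=len(toy["observations"]))
    buf.load_dataset(toy)
    batch = buf.sample(batch_size=3)
    print({k: len(v) for k, v in batch.items()})  # every field has 3 entries
```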

Define the models and optimizers:

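import numpy as np
import torch

from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian
from offlinerlkit.nets import MLP

# an actor and two critics (as in SAC), each on top of an MLP backbone
actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims)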
critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims)
critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims)
dist = TanhDiagGaussian(
    latent_dim=getattr(actor_backbone, "output_dim"),
    output_dim=args.action_dim,
    unbounded=True,
    conditioned_sigma=True
)
actor = ActorProb(actor_backbone, dist, args.device)
critic1 = Critic(critic1_backbone, args.device)
critic2 = Critic(critic2_backbone, args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
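CQL builds on SAC, and `TanhDiagGaussian` is, judging by its name, the tanh-squashed diagonal Gaussian that SAC-style actors use to keep actions in (-1, 1). The sketch below shows that standard construction in plain Python, including the change-of-variables correction to the log-probability. It is not the library's code, and the meaning of `unbounded` and `conditioned_sigma` (presumably an unclipped mean and a state-dependent standard deviation) is an assumption. The two critics are typically combined by taking the minimum of their estimates.

```python
import math
import random
from typing import List, Sequence, Tuple


def sample_tanh_gaussian(mu: Sequence[float], log_std: Sequence[float],
                         rng: random.Random, eps: float = 1e-6) -> Tuple[List[float], float]:
    """Sample a = tanh(u), u ~ N(mu, diag(exp(log_std)^2)); return (a, log pi(a))."""
    action: List[float] = []
    log_prob = 0.0
    for m, ls in zip(mu, log_std):
        std = math.exp(ls)
        u = rng.gauss(m, std)  # pre-squash Gaussian sample
        # log-density of u under the 1-D Gaussian N(m, std^2); dimensions are independent
        log_prob += -0.5 * ((u - m) / std) ** 2 - ls - 0.5 * math.log(2.0 * math.pi)
        a = math.tanh(u)  # squash into (-1, 1)
        # change of variables: log pi(a) = log N(u) - log|da/du|, with da/du = 1 - tanh(u)^2
        log_prob -= math.log(1.0 - a * a + eps)
        action.append(a)
    return action, log_prob


if __name__ == "__main__":
    rng = random.Random(0)
    a, logp = sample_tanh_gaussian(mu=[0.0, 0.5], log_std=[-1.0, -0.5], rng=rng)
    print(a, logp)
```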

Set up the policy:

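from offlinerlkit.policy import CQLPolicy

# entropy coefficient alpha: a fixed value here; see run_example/run_cql.py for how the full script sets it
alpha = args.alpha

policy = CQLPolicy(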
    actor,
    critic1,
    critic2,
    actor_optim,
    critic1_optim,
    critic2_optim,
    action_space=env.action_space,
    tau=args.tau,
    gamma=args.gamma,
    alpha=alpha,
    cql_weight=args.cql_weight,
    temperature=args.temperature,
    max_q_backup=args.max_q_backup,
    deterministic_backup=args.deterministic_backup,
    with_lagrange=args.with_lagrange,
    lagrange_threshold=args.lagrange_threshold,
    cql_alpha_lr=args.cql_alpha_lr,
    num_repeart_actions=args.num_repeat_actions
)
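What distinguishes CQL from SAC is a conservative penalty added to each critic's loss: it pushes down a soft maximum (log-sum-exp) of Q over sampled actions and pushes up Q on the dataset actions. A minimal plain-Python sketch of the CQL(H) form is below, assuming `cql_weight` scales the penalty and `temperature` scales the log-sum-exp. The importance-sampling corrections of the reference implementation are omitted, and `with_lagrange` presumably replaces the fixed weight with one tuned against `lagrange_threshold`, which this sketch does not model.

```python
import math
from typing import List, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def cql_penalty(q_sampled: List[List[float]], q_data: List[float],
                cql_weight: float = 1.0, temperature: float = 1.0) -> float:
    """Batch-mean CQL(H) regularizer for one critic.

    q_sampled[i] holds Q(s_i, a) for proposal actions (in CQL, random actions and
    actions from the current policy); q_data[i] is Q(s_i, a_i) for the dataset action.
    """
    assert len(q_sampled) == len(q_data), "one row of proposals per dataset transition"
    # temperature * logsumexp(Q / temperature) approaches max(Q) as temperature -> 0
    soft_max = [temperature * logsumexp([q / temperature for q in row]) for row in q_sampled]
    n = len(q_data)
    return cql_weight * (sum(soft_max) / n - sum(q_data) / n)


if __name__ == "__main__":
    # one state, three proposal actions, dataset action valued at 1.0
    print(cql_penalty([[1.0, 2.0, 0.5]], [1.0], cql_weight=5.0))
```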

Define the logger:

from offlinerlkit.utils.logger import Logger, make_log_dirs

log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
output_config = {
    "consoleout_backup": "stdout",
    "policy_training_progress": "csv",
    "tb": "tensorboard"
}
logger = Logger(log_dirs, output_config)
logger.log_hyperparameters(vars(args))

Pass all components to the trainer and start training:

from offlinerlkit.policy_trainer import MFPolicyTrainer

policy_trainer = MFPolicyTrainer(
    policy=policy,
    eval_env=env,
    buffer=buffer,
    logger=logger,
    epoch=args.epoch,
    step_per_epoch=args.step_per_epoch,
    batch_size=args.batch_size,
    eval_episodes=args.eval_episodes
)

policy_trainer.train()
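The arguments of `MFPolicyTrainer` suggest the usual offline training loop: each of `epoch` epochs runs `step_per_epoch` updates on mini-batches of `batch_size`, then evaluates the policy for `eval_episodes` episodes, for `epoch * step_per_epoch` gradient steps in total. The sketch below assumes exactly that structure, with hypothetical callables standing in for the buffer, the policy update, and the evaluation environment; it is not the library's trainer.

```python
from typing import Callable, Dict, List


def train_loop(
    sample_batch: Callable[[int], Dict],
    learn: Callable[[Dict], Dict[str, float]],
    evaluate: Callable[[int], List[float]],
    epoch: int,
    step_per_epoch: int,
    batch_size: int,
    eval_episodes: int,
) -> List[float]:
    """Hypothetical epoch/step/evaluate loop; returns the mean evaluation return of each epoch."""
    history: List[float] = []
    for _ in range(epoch):
        for _ in range(step_per_epoch):
            learn(sample_batch(batch_size))  # one update on a batch drawn from the fixed dataset
        returns = evaluate(eval_episodes)  # roll out the current policy for eval_episodes episodes
        history.append(sum(returns) / len(returns))
    return history


if __name__ == "__main__":
    updates = []
    history = train_loop(
        sample_batch=lambda n: {"batch_size": n},
        learn=lambda batch: updates.append(batch) or {},
        evaluate=lambda k: [float(len(updates))] * k,
        epoch=2, step_per_epoch=3, batch_size=256, eval_episodes=2,
    )
    print(len(updates), history)  # prints 6 [3.0, 6.0]
```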

Tune

Ray makes it easy to tune an algorithm's hyperparameters in parallel:

import ray
from ray import tune

# get_args() and run_exp() come from tune_example/tune_mopo.py
ray.init()
# load default args
args = get_args()

config = {}
real_ratios = [0.05, 0.5]
seeds = list(range(2))
config["real_ratio"] = tune.grid_search(real_ratios)
config["seed"] = tune.grid_search(seeds)

analysis = tune.run(
    run_exp,
    name="tune_mopo",
    config=config,
    resources_per_trial={
        "gpu": 0.5
    }
)

The full script is at tune_example/tune_mopo.py.
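When `tune.grid_search` is used on several keys, Ray Tune expands them into their cross product, so the config above launches 2 real ratios × 2 seeds = 4 trials, and `"gpu": 0.5` lets two trials share one GPU. The helper names below are hypothetical; the sketch only reproduces that arithmetic in plain Python, does not call Ray, and ignores CPU limits.

```python
import itertools
import math
from typing import Any, Dict, List


def expand_grid(grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Cross product of the grid-searched values: one config dict per trial."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in itertools.product(*(grid[k] for k in keys))]


def concurrent_trials_per_gpu(gpu_per_trial: float) -> int:
    """How many trials with a fractional GPU request fit on one GPU at the same time."""
    # small epsilon guards against float round-off, e.g. 1 / 0.1
    return math.floor(1.0 / gpu_per_trial + 1e-9)


if __name__ == "__main__":
    trials = expand_grid({"real_ratio": [0.05, 0.5], "seed": list(range(2))})
    for config in trials:
        print(config)
    print(len(trials), "trials,", concurrent_trials_per_gpu(0.5), "running at once per GPU")
```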

Log

Our logger supports a variety of record file types:

  • .txt: a backup of stdout
  • .csv: losses, performance, and other metrics recorded during training
  • .tfevents: TensorBoard event files for visualizing training curves
  • .json: a backup of the hyperparameters

Logs are also organized in a clear directory structure:

└─log (root dir)
    └─task
        ├─algo_0
        │   ├─seed_0&timestamp_xxx
        │   │   ├─checkpoint
        │   │   ├─model
        │   │   ├─record
        │   │   │   ├─tb
        │   │   │   ├─consoleout_backup.txt
        │   │   │   ├─policy_training_progress.csv
        │   │   │   └─hyper_param.json
        │   │   └─result
        │   └─seed_1&timestamp_xxx
        └─algo_1
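A path-building helper in the spirit of `make_log_dirs` could look like the sketch below. It is hypothetical: the name `build_log_dirs`, its argument order, and the timestamp format are assumptions, and it only computes the paths from the tree without creating anything on disk.

```python
from datetime import datetime
from pathlib import Path
from typing import Dict


def build_log_dirs(root: str, task: str, algo: str, seed: int, timestamp: str) -> Dict[str, Path]:
    """Hypothetical helper: compute (not create) the run directories shown in the tree."""
    run_dir = Path(root) / task / algo / f"seed_{seed}&timestamp_{timestamp}"
    dirs = {"run": run_dir}
    for sub in ("checkpoint", "model", "record", "result"):
        dirs[sub] = run_dir / sub
    dirs["tb"] = dirs["record"] / "tb"  # TensorBoard event files live under record/
    return dirs


if __name__ == "__main__":
    stamp = datetime.now().strftime("%y-%m%d-%H%M%S")  # timestamp format is an assumption
    for name, path in build_log_dirs("log", "hopper-medium-v2", "cql", 0, stamp).items():
        print(f"{name:<10} {path.as_posix()}")
```

To actually create the directories, call `path.mkdir(parents=True, exist_ok=True)` on each returned path.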

The following example shows how the logger is used; the full usage is in offlinerlkit/policy_trainer/mb_policy_trainer.py.

First, import the logger utilities:

from offlinerlkit.utils.logger import Logger, make_log_dirs

Then initialize the logger:

log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
# key: output file name, value: output handler type
output_config = {
    "consoleout_backup": "stdout",
    "policy_training_progress": "csv",
    "dynamics_training_progress": "csv",
    "tb": "tensorboard"
}
logger = Logger(log_dirs, output_config)
logger.log_hyperparameters(vars(args))

Let's log some metrics:

# log evaluation metrics (norm_ep_rew_mean etc. are computed from the evaluation episodes)
logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean)
logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std)
logger.logkv("eval/episode_length", ep_length_mean)
logger.logkv("eval/episode_length_std", ep_length_std)
# set timestep
logger.set_timestep(num_timesteps)
# dump results to the record files
logger.dumpkvs()
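Metrics are buffered with `logkv`, stamped with `set_timestep`, and written out as one record by `dumpkvs`. The class below is a hypothetical, CSV-only illustration of that buffering pattern using the standard `csv` module; the real `Logger` also writes to stdout and TensorBoard, and nothing about its column layout is assumed beyond a `timestep` column.

```python
import csv
import io
from typing import Dict, Optional, TextIO


class MiniKVLogger:
    """Hypothetical CSV-only logger showing the logkv / set_timestep / dumpkvs pattern."""

    def __init__(self, stream: TextIO) -> None:
        self.stream = stream
        self.pending: Dict[str, float] = {}
        self.timestep = 0
        self.writer: Optional[csv.DictWriter] = None

    def logkv(self, key: str, value: float) -> None:
        self.pending[key] = value  # buffered until dumpkvs(); logging a key twice keeps the last value

    def set_timestep(self, timestep: int) -> None:
        self.timestep = timestep

    def dumpkvs(self) -> None:
        row = {"timestep": self.timestep, **self.pending}
        if self.writer is None:
            # The first dump fixes the columns; DictWriter raises ValueError on unseen keys later.
            self.writer = csv.DictWriter(self.stream, fieldnames=list(row))
            self.writer.writeheader()
        self.writer.writerow(row)
        self.pending.clear()


if __name__ == "__main__":
    out = io.StringIO()
    logger = MiniKVLogger(out)
    for step, score in [(1000, 48.0), (2000, 55.5)]:
        logger.logkv("eval/normalized_episode_reward", score)
        logger.set_timestep(step)
        logger.dumpkvs()
    print(out.getvalue())
```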

Plot

Compare the results of several algorithms on one task with the plotting script:

python run_example/plotter.py --algos "mopo" "cql" --task "hopper-medium-replay-v2"
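How `plotter.py` aggregates runs is not shown here. A common approach is to average each algorithm's curve over seeds and plot the mean with a shaded ± std band. The sketch below covers only the averaging step for progress CSVs, assuming (hypothetically) columns named `timestep` and `eval/normalized_episode_reward`; drawing the result with matplotlib is left out.

```python
import csv
import io
import statistics
from typing import Dict, List, Sequence, Tuple

# Column names are assumptions: a timestep column plus the metric key logged via logkv.
METRIC = "eval/normalized_episode_reward"


def read_curve(csv_text: str, metric: str = METRIC) -> Dict[int, float]:
    """Parse one run's progress CSV into {timestep: metric}, skipping rows without the metric."""
    rows = csv.DictReader(io.StringIO(csv_text))
    return {int(r["timestep"]): float(r[metric]) for r in rows if r.get(metric)}


def average_curves(curves: Sequence[Dict[int, float]]) -> List[Tuple[int, float, float]]:
    """(timestep, mean, population std) across seeds, at timesteps every run has logged."""
    shared = sorted(set.intersection(*(set(c) for c in curves)))
    result = []
    for t in shared:
        values = [c[t] for c in curves]
        result.append((t, statistics.fmean(values), statistics.pstdev(values)))
    return result


if __name__ == "__main__":
    seed0 = "timestep,eval/normalized_episode_reward\n1000,40.0\n2000,50.0\n"
    seed1 = "timestep,eval/normalized_episode_reward\n1000,60.0\n2000,70.0\n3000,80.0\n"
    for t, mean, std in average_curves([read_curve(seed0), read_curve(seed1)]):
        print(t, mean, std)  # a plot would draw mean with a band from mean - std to mean + std
```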

Citing OfflineRL-Kit

If you use OfflineRL-Kit in your work, please cite it with the following BibTeX entry:

@misc{offinerlkit,
  author = {Yihao Sun},
  title = {OfflineRL-Kit: An Elegant PyTorch Offline Reinforcement Learning Library},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/yihaosun1124/OfflineRL-Kit}},
}