
This package provides ROS support for Stable Baselines3. It allows you to train robotics RL agents in the real world and simulations using ROS and SB3.


SB3 ROS Support: The ROS Support Package for Stable Baselines3


This package extends Stable Baselines3 (SB3) with ROS support, allowing you to train robotic RL agents both in the real world and in simulation using ROS.

It builds on the SB3 models in the FRobs_RL package and adds the following features:

  1. Support for goal-conditioned RL tasks
  2. HER (Hindsight Experience Replay) for goal-conditioned RL tasks
  3. Support for training custom environments with RealROS or MultiROS frameworks
  4. Compatibility with newer versions of SB3, which use Gymnasium instead of Gym.
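Features 1 and 2 revolve around goal-conditioned tasks, where each step records the goal actually achieved and the goal the agent was asked to reach, and the reward is computed from the distance between them. The sketch below illustrates the idea behind HER's "future" strategy in plain Python: every stored transition is copied k times with its desired goal replaced by a goal actually achieved at that step or later in the same episode, so failed episodes still produce successful examples. It is not this package's implementation. SB3 itself provides HER as HerReplayBuffer and relabels lazily at sampling time, whereas this sketch stores the copies eagerly for clarity; the transition layout, the 0.05 threshold and k=4 are assumptions chosen for illustration.

```python
import math
import random
from typing import Dict, List, Optional, Sequence

Transition = Dict[str, object]


def sparse_reward(achieved_goal: Sequence[float], desired_goal: Sequence[float],
                  threshold: float = 0.05) -> float:
    """0.0 if the achieved goal lies within `threshold` of the desired goal, else -1.0."""
    return 0.0 if math.dist(achieved_goal, desired_goal) < threshold else -1.0


def her_relabel(episode: List[Transition], k: int = 4, threshold: float = 0.05,
                rng: Optional[random.Random] = None) -> List[Transition]:
    """Augment one finished episode using HER's 'future' goal-selection strategy.

    Each transition is a dict holding at least 'achieved_goal' (the goal actually
    reached after the action) and 'desired_goal' (the goal the agent was given).
    """
    rng = rng or random.Random(0)
    out: List[Transition] = []
    for t, tr in enumerate(episode):
        # Keep the original transition, rewarded against its original goal.
        out.append({**tr, "reward": sparse_reward(tr["achieved_goal"], tr["desired_goal"], threshold)})
        # Add k copies whose goal was actually achieved at step t or later.
        for _ in range(k):
            new_goal = rng.choice(episode[t:])["achieved_goal"]
            out.append({**tr, "desired_goal": new_goal,
                        "reward": sparse_reward(tr["achieved_goal"], new_goal, threshold)})
    return out


if __name__ == "__main__":
    # A failed episode: the end effector moves along x but never reaches (1.0, 1.0).
    episode = [{"achieved_goal": (0.1 * (i + 1), 0.0), "desired_goal": (1.0, 1.0)} for i in range(3)]
    buffer = her_relabel(episode, k=4)
    print(len(buffer), sum(tr["reward"] == 0.0 for tr in buffer))
```

All original rewards in this episode are -1.0, yet the relabelled copies include transitions with reward 0.0, which is exactly the extra learning signal HER is meant to provide for sparse-reward reach tasks.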

Prerequisites

Before installing this package, make sure you have the following prerequisites:

ROS Installation

This package requires a working ROS installation. If you haven't installed ROS yet, follow the official ROS installation guide for your operating system. The package has been tested with ROS Noetic.

ROS Workspace

You also need a ROS (catkin) workspace to build and run your ROS packages. If you don't have one yet, follow the steps in the official guide to create one.

The instructions below assume Ubuntu 20.04 and ROS Noetic. If you are using a different operating system or ROS version, adapt the commands accordingly.

Installation

To get started, follow these steps:

  1. Clone the repository:

    cd ~/catkin_ws/src
    git clone https://github.com/ncbdrck/sb3_ros_support.git
  2. This package relies on several Python packages. Install them by running the following commands:

    # Install pip if you haven't already by running this command
    sudo apt-get install python3-pip
    
    # install the required Python packages by running
    cd ~/catkin_ws/src/sb3_ros_support/
    pip3 install -r requirements.txt
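  3. Build the ROS packages and source the environment. The catkin build command is provided by catkin_tools (on Ubuntu 20.04 you can install it with the python3-catkin-tools package):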

    cd ~/catkin_ws/
    rosdep install --from-paths src --ignore-src -r -y
    catkin build
    source devel/setup.bash
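After building and sourcing the workspace, a quick sanity check can catch common setup mistakes before you start training. The following is a minimal sketch, not part of this package: it assumes the standard ROS_DISTRO and ROS_PACKAGE_PATH variables set by the ROS setup scripts and the ~/catkin_ws workspace used above, and the default package list is hypothetical (see requirements.txt for the actual dependencies).

```python
import importlib.util
import os
from typing import List, Mapping, Optional, Sequence

# Hypothetical list: check requirements.txt for the package's real dependencies.
DEFAULT_PACKAGES = ("rospy", "stable_baselines3", "gymnasium")


def check_setup(environ: Optional[Mapping[str, str]] = None,
                workspace_src: str = "~/catkin_ws/src",
                packages: Sequence[str] = DEFAULT_PACKAGES) -> List[str]:
    """Return a list of human-readable problems; an empty list means all checks passed."""
    environ = os.environ if environ is None else environ
    problems: List[str] = []

    # ROS_DISTRO is set by /opt/ros/<distro>/setup.bash.
    distro = environ.get("ROS_DISTRO")
    if distro is None:
        problems.append("ROS_DISTRO is not set; source /opt/ros/noetic/setup.bash")
    elif distro != "noetic":
        problems.append(f"ROS_DISTRO is '{distro}'; this package has been tested with 'noetic'")

    # Sourcing devel/setup.bash puts the workspace's src folder on ROS_PACKAGE_PATH.
    src = os.path.normpath(os.path.expanduser(workspace_src))
    entries = [os.path.normpath(p) for p in environ.get("ROS_PACKAGE_PATH", "").split(":") if p]
    if src not in entries:
        problems.append(f"{src} is not on ROS_PACKAGE_PATH; run 'source devel/setup.bash'")

    # Python dependencies must be importable from the current interpreter.
    for name in packages:
        if importlib.util.find_spec(name) is None:
            problems.append(f"Python package '{name}' cannot be imported")
    return problems


if __name__ == "__main__":
    issues = check_setup()
    print("\n".join(issues) if issues else "Setup looks OK")
```

Passing environ and workspace_src explicitly makes the function easy to test without a real ROS installation; run it with no arguments from the same shell you will train in.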

Usage

You can refer to the examples to see how to use this package to train robots with ROS and Stable Baselines3.

The examples also showcase:

  • How to use RealROS to create a real-world environment for RL applications.
  • How to train the Rx200 robot directly in the real world to perform a simple reach task.
  • How to use the MultiROS framework to create a simulation environment for the same robot, train it in simulation, and then transfer the learned policy to the real-world environment.
  • How to train in both environments (sim and real) simultaneously in real time to obtain a generalised policy that performs well in both.

The installation instructions for the examples are provided in the respective repositories.

Alternatively, you can follow the example below to train a robot using this package:

```python
#!/usr/bin/env python3

# ROS packages required
import rospy

# simulation or real-world environment framework
import uniros as gym
# or
# import gymnasium as gym

# the custom ROS-based environments (real or sim)
import rl_environments

# Models
from sb3_ros_support.sac import SAC
from sb3_ros_support.sac_goal import SAC_GOAL


if __name__ == '__main__':

    # normal environments
    env_base = gym.make('RX200ReacherSim-v0', gazebo_gui=False)

    # or goal-conditioned environments
    env_goal = gym.make('RX200ReacherGoalSim-v0', gazebo_gui=True, ee_action_type=False,
                        delta_action=False, reward_type="sparse")

    # reset the environments
    env_base.reset()
    env_goal.reset()

    # model, config and output locations
    pkg_path = "rl_environments"
    config_file_name_base = "sac.yaml"
    config_file_name_goal = "sac_goal.yaml"
    save_path = "/models/sac/"
    log_path = "/logs/sac/"

    # --------------------------------------------------------------------------------------------
    # Creating a model - normal environments
    model_base = SAC(env_base, save_path, log_path, model_pkg_path=pkg_path,
                     config_file_pkg=pkg_path, config_filename=config_file_name_base)

    # train and save the model
    model_base.train()
    model_base.save_model()

    # --------------------------------------------------------------------------------------------
    # Creating a model - goal-conditioned environments
    model_goal = SAC_GOAL(env_goal, save_path, log_path, model_pkg_path=pkg_path,
                          config_file_pkg=pkg_path, config_filename=config_file_name_goal)

    # train and save the model
    model_goal.train()
    model_goal.save_model()

    # --------------------------------------------------------------------------------------------
    # validate the models
    # With the Gymnasium API, reset() returns (observation, info)
    obs, _ = env_base.reset()
    episodes = 1000
    epi_count = 0
    while epi_count < episodes:
        action, _states = model_base.predict(observation=obs, deterministic=True)
        obs, _, terminated, truncated, info = env_base.step(action)
        if terminated or truncated:
            epi_count += 1
            rospy.logwarn("Episode: " + str(epi_count))
            obs, _ = env_base.reset()

    # The goal-conditioned model can be validated on env_goal with the same procedure.
    # Not shown here.

    # To load a saved model and validate it, use the following code.
    # Replace "trained_model_name_without_.zip" with the file name of your saved model,
    # without the ".zip" extension.
    model = SAC.load_trained_model(save_path + "trained_model_name_without_.zip",
                                   model_pkg=pkg_path,
                                   env=env_goal,
                                   config_filename=config_file_name_goal)
    # Then follow the same validation procedure as above.

    # close the environments only after you have finished using them
    env_base.close()
    env_goal.close()
```
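The validation loop above relies on the Gymnasium API: reset() returns an (observation, info) pair, and step() returns five values, with terminated and truncated reported separately. The self-contained sketch below factors that loop into a reusable evaluate() function and runs it against a hypothetical one-dimensional ToyReachEnv instead of a ROS environment. Reporting success through info["is_success"] is an assumption made for the toy environment; check what your own environment actually puts in info.

```python
import random
from typing import Any, Callable, Dict, Tuple

Obs = Tuple[float, float]  # (current position, goal position)


class ToyReachEnv:
    """Hypothetical 1-D reach task that follows the Gymnasium reset/step contract."""

    def __init__(self, max_steps: int = 20, tolerance: float = 0.05, seed: int = 0):
        self.max_steps = max_steps
        self.tolerance = tolerance
        self.rng = random.Random(seed)

    def reset(self, *, seed=None, options=None) -> Tuple[Obs, Dict[str, Any]]:
        if seed is not None:
            self.rng.seed(seed)
        self.pos = 0.0
        self.goal = self.rng.uniform(-1.0, 1.0)
        self.steps = 0
        return (self.pos, self.goal), {}

    def step(self, action: float):
        # Move by the action, clipped to a maximum step size of 0.1.
        self.pos += max(-0.1, min(0.1, action))
        self.steps += 1
        success = abs(self.goal - self.pos) < self.tolerance
        terminated = success                                      # task solved
        truncated = not success and self.steps >= self.max_steps  # time limit reached
        reward = 0.0 if success else -1.0
        return (self.pos, self.goal), reward, terminated, truncated, {"is_success": success}


def evaluate(policy: Callable[[Obs], float], env, episodes: int) -> float:
    """Run `episodes` complete episodes and return the fraction that ended in success."""
    successes = 0
    finished = 0
    obs, _ = env.reset()  # Gymnasium: reset() -> (observation, info)
    while finished < episodes:
        obs, _, terminated, truncated, info = env.step(policy(obs))
        if terminated or truncated:
            finished += 1
            successes += int(info.get("is_success", False))
            obs, _ = env.reset()
    return successes / episodes


if __name__ == "__main__":
    greedy = lambda obs: obs[1] - obs[0]  # move straight toward the goal
    print(evaluate(greedy, ToyReachEnv(seed=0), episodes=100))
```

With a trained sb3_ros_support model, the policy argument could be a small wrapper around model_base.predict(observation=obs, deterministic=True)[0], mirroring the call in the example above.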

Note: The examples are provided for reference only. You may need to modify the code to suit your specific needs.

License

This package is released under the MIT Licence. Please see the LICENCE file for more details.

Acknowledgements

We would like to thank the projects and communities whose work this package builds on, as well as the authors of the libraries and tools it uses.

Contact

For questions, suggestions, or collaborations, contact the project maintainer at j.kapukotuwa@research.ait.ie.