bridge_data_v2

Primary language: Python. License: MIT.

JAX BC/RL Implementations for BridgeData V2

This repository provides code for training on BridgeData V2.

We provide implementations for the following subset of methods described in the paper:

  • Goal-conditioned BC
  • Goal-conditioned BC with a diffusion policy
  • Goal-conditioned IQL
  • Goal-conditioned contrastive RL

The code for the language-conditioned BC method may be released soon.

The official implementations and papers for all the methods can be found here:

Please open a GitHub issue if you encounter problems with this code.

Data

The raw dataset (JPEGs, PNGs, and pkl files) can be downloaded from the website. For training, the raw data must be converted into a TFRecord format that is compatible with the data loader. First, use data_processing/bridgedata_raw_to_numpy.py to convert the raw data into numpy files. Then, use data_processing/bridgedata_numpy_to_tfrecord.py to convert the numpy files into TFRecord files.
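The two conversion steps must run in order, since the second script consumes the first script's output. A minimal sketch of chaining them from Python is below. The script paths come from the text above, but the `--input_path` and `--output_path` flag names, the helper names, and the directory arguments are assumptions for illustration; check each script's flag definitions before running.

```python
import subprocess
import sys

# Script paths are taken from the README text.
RAW_TO_NUMPY = "data_processing/bridgedata_raw_to_numpy.py"
NUMPY_TO_TFRECORD = "data_processing/bridgedata_numpy_to_tfrecord.py"


def conversion_commands(raw_dir, numpy_dir, tfrecord_dir):
    """Return the two conversion commands in the order they must run.

    ASSUMPTION: both scripts accept --input_path and --output_path.
    """
    return [
        [sys.executable, RAW_TO_NUMPY,
         "--input_path", raw_dir, "--output_path", numpy_dir],
        [sys.executable, NUMPY_TO_TFRECORD,
         "--input_path", numpy_dir, "--output_path", tfrecord_dir],
    ]


def run_conversion(raw_dir, numpy_dir, tfrecord_dir):
    """Run raw -> numpy, then numpy -> TFRecord, stopping on the first failure."""
    for cmd in conversion_commands(raw_dir, numpy_dir, tfrecord_dir):
        # check=True raises CalledProcessError if a step exits non-zero.
        subprocess.run(cmd, check=True)
```

Calling `run_conversion("raw", "numpy", "tfrecord")` from the repository root would run both steps; the intermediate numpy directory is the output of the first command and the input of the second.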

Training

To start training, run the command below. Replace METHOD with one of gc_bc, gc_ddpm_bc, gc_iql, or contrastive_rl_td, and replace NAME with a name for the run.

python experiments/train.py \
    --config experiments/configs/train_config.py:METHOD \
    --bridgedata_config experiments/configs/data_config.py:all \
    --name NAME

Training hyperparameters can be modified in experiments/configs/train_config.py. Data parameters (e.g. subsets to include/exclude) can be modified in experiments/configs/data_config.py.
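When launching several runs, it can help to build the training command programmatically and reject a mistyped METHOD before anything starts. The sketch below is one way to do that, not part of the repository: `build_train_command`, its `data_split` parameter, and the run name `my_run` are hypothetical, while the script path, flags, and method names are copied from the command above.

```python
import shlex

# The four METHOD values listed in the Training section.
METHODS = ("gc_bc", "gc_ddpm_bc", "gc_iql", "contrastive_rl_td")


def build_train_command(method, name, data_split="all"):
    """Return the training command as an argument list.

    The part after the colon in each config path selects a named
    configuration inside that config file.
    """
    if method not in METHODS:
        raise ValueError(f"unknown method {method!r}; expected one of {METHODS}")
    return [
        "python", "experiments/train.py",
        "--config", f"experiments/configs/train_config.py:{method}",
        "--bridgedata_config", f"experiments/configs/data_config.py:{data_split}",
        "--name", name,
    ]


# Print a copy-pasteable shell command for a hypothetical run name.
print(shlex.join(build_train_command("gc_bc", "my_run")))
```

The returned list can be passed directly to `subprocess.run`, which avoids shell quoting issues in run names.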

Evaluation

First, set up the robot hardware according to our guide. Install our WidowX robot controller stack from this repo. Then, run the command:

python experiments/eval.py \
    --num_timesteps NUM_TIMESTEPS \
    --video_save_path VIDEO_DIR \
    --checkpoint_path CHECKPOINT_PATH \
    --wandb_run_name WANDB_RUN_NAME \
    --blocking

The script loads some information about the checkpoint from its corresponding WandB run.

Provided Checkpoints

Checkpoints for GCBC, D-GCBC, GCIQL, CRL, and RT-1 are available here. Each checkpoint (except RT-1) has an associated JSON file with its configuration information. To evaluate these checkpoints with the above evaluation script, modify the references to the wandb run configuration to use the dictionary provided in the JSON file instead.
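As a rough illustration of that modification, the sketch below reads a checkpoint's JSON file into a plain dictionary that can stand in for the WandB run configuration. The helper name, file name, and key shown in the comments are hypothetical; the actual keys depend on what the evaluation script reads from the run configuration.

```python
import json


def load_checkpoint_config(json_path):
    """Read a checkpoint's configuration JSON into a plain dictionary."""
    with open(json_path, "r") as f:
        config = json.load(f)
    if not isinstance(config, dict):
        raise ValueError(f"expected a JSON object in {json_path}")
    return config


# In the evaluation script, wherever a value is looked up on the WandB run's
# configuration, look it up in this dictionary instead, for example:
#   config = load_checkpoint_config("CHECKPOINT_CONFIG.json")  # hypothetical file name
#   value = config["some_key"]                                 # hypothetical key
```

With this change the script no longer needs network access to WandB for the provided checkpoints.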

An evaluation script for the RT-1 checkpoint is available in this separate repo (TODO).

We don't currently have checkpoints for ACT or LCBC available but may release them soon.

Environment

The dependencies for this codebase can be installed in a conda environment:

conda create -n jaxrl python=3.10
conda activate jaxrl
pip install -e .
pip install -r requirements.txt

For GPU:

pip install --upgrade "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

For TPU:

pip install --upgrade "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

See the JAX GitHub page for more details on installing JAX.
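After installing, a quick check like the following can confirm which backend JAX picked up. It is a minimal sketch: the helper name `describe_jax_backend` is hypothetical, and it relies only on `jax.__version__`, `jax.default_backend()`, and `jax.devices()`.

```python
import importlib.util


def describe_jax_backend():
    """Return a one-line summary of the installed JAX version and backend."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    devices = jax.devices()  # devices visible to the default backend
    return (
        f"jax {jax.__version__}: {jax.default_backend()} backend, "
        f"{len(devices)} device(s)"
    )


print(describe_jax_backend())
```

If the summary reports the `cpu` backend on a machine with an accelerator, the accelerator-specific install command above most likely did not take effect in the active environment.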

Cite

This code is based on dibyaghosh/jaxrl_m.

If you use this code and/or BridgeData V2 in your work, please cite the paper with:

@inproceedings{walke2023bridgedata,
  title={BridgeData V2: A Dataset for Robot Learning at Scale},
  author={Walke, Homer and Black, Kevin and Lee, Abraham and Kim, Moo Jin and Du, Max and Zheng, Chongyi and Zhao, Tony and Hansen-Estruch, Philippe and Vuong, Quan and He, Andre and Myers, Vivek and Fang, Kuan and Finn, Chelsea and Levine, Sergey},
  booktitle={Conference on Robot Learning (CoRL)},
  year={2023}
}