
Primary language: Python. License: MIT.

Scaling Cross-Embodied Learning: One Policy for Manipulation, Navigation, Locomotion and Aviation


Ria Doshi, Homer Walke, Oier Mees, Sudeep Dasari, Sergey Levine


This repo contains code for training and finetuning CrossFormer, a transformer-based robot policy trained on 900K robot trajectories across 20 different robot embodiments. The codebase is built on the Octo codebase.

Figure: CrossFormer model.

Get Started

Follow the installation instructions, then load the pre-trained CrossFormer model. See our Colab notebook for an inference example.

from crossformer.model.crossformer_model import CrossFormerModel
model = CrossFormerModel.load_pretrained("hf://rail-berkeley/crossformer")
print(model.get_pretty_spec())

Out of the box, CrossFormer can control single- and dual-arm manipulation systems, wheeled robots, quadcopters, and quadrupeds, and can be instructed via language commands or goal images. CrossFormer uses a modular attention structure in its transformer backbone, which allows it to be finetuned effectively to robot setups with new sensory inputs, action spaces, and morphologies, using only a small target-domain dataset and an accessible compute budget.

Installation

conda create -n crossformer python=3.10
conda activate crossformer
pip install -e .
pip install -r requirements.txt

For GPU:

pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

For TPU:

pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

See the JAX GitHub page for more details on installing JAX.
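After installing, it can be worth confirming which backend JAX actually picked up before launching a long job. The helper below is a minimal sketch and is not part of this repo; the name `describe_jax_backend` is ours. It only relies on `jax.__version__` and `jax.devices()`.

```python
import importlib.util


def describe_jax_backend():
    """Return a one-line summary of the JAX install (hypothetical helper, not in the repo)."""
    # Avoid an ImportError if the JAX install step was skipped.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    devices = jax.devices()  # devices of the default backend (cpu, gpu, or tpu)
    return f"jax {jax.__version__} on {devices[0].platform} ({len(devices)} device(s))"


if __name__ == "__main__":
    print(describe_jax_backend())
```

If the reported platform is `cpu` on a machine with an accelerator, the accelerator-specific JAX wheel was probably not installed.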

Checkpoint

The pre-trained CrossFormer checkpoint (130M parameters) is available on Hugging Face at rail-berkeley/crossformer.

CrossFormer Evaluation

See scripts/server.py for an example of how to host the CrossFormer model on a server for remote inference. Remote inference is useful for evaluating on robots that cannot be directly connected to a powerful GPU.
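To illustrate the remote-inference pattern in general terms, here is a minimal sketch that uses only the Python standard library. It is not the implementation in scripts/server.py: `dummy_policy`, the `/query` path, and the JSON payload layout are all assumptions made for illustration, and the dummy policy would be replaced by a real model call on the GPU machine.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.request import ProxyHandler, Request, build_opener


def dummy_policy(observation):
    """Hypothetical stand-in for a model call: returns a zero action vector."""
    return [0.0] * int(observation.get("action_dim", 7))


class PolicyHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read the JSON observation sent by the robot-side client.
        length = int(self.headers.get("Content-Length", 0))
        observation = json.loads(self.rfile.read(length))
        body = json.dumps({"action": dummy_policy(observation)}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        pass  # silence per-request logging


def start_server(host="127.0.0.1", port=0):
    """Serve the policy in a background thread; port=0 picks a free port."""
    server = HTTPServer((host, port), PolicyHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def query_action(url, observation):
    """Client side: POST an observation and return the predicted action."""
    request = Request(
        url,
        data=json.dumps(observation).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    opener = build_opener(ProxyHandler({}))  # ignore any system proxy settings
    with opener.open(request, timeout=10) as response:
        return json.loads(response.read())["action"]


if __name__ == "__main__":
    server = start_server()
    url = f"http://127.0.0.1:{server.server_address[1]}/query"
    print(query_action(url, {"action_dim": 7}))
    server.shutdown()
    server.server_close()
```

The robot-side machine only needs the `query_action` half, so it can run without JAX or a GPU.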

CrossFormer Pre-training

To reproduce CrossFormer pre-training, edit scripts/configs/pretrain_config.py to point to your data and log directory. Then, run:

python scripts/train.py --config scripts/configs/pretrain_config.py

To download the pre-training datasets from the Open X-Embodiment Dataset, install the rlds_dataset_mod package and run the prepare_open_x.sh script.

Pre-training takes 47 hours on a TPUv5-256 pod.

CrossFormer Finetuning

To run finetuning on your own dataset, convert your dataset to the RLDS format using this repository. Then, edit scripts/configs/finetune_config.py, and run:

python scripts/finetune.py --config scripts/configs/finetune_config.py

There are a few options for finetuning CrossFormer. If your dataset has an observation space and action space that were used during pre-training, you can finetune from entirely pre-trained weights, using the existing observation tokenizers, action heads, and transformer backbone. Otherwise, you can initialize new observation tokenizers and/or action heads while keeping the pre-trained transformer backbone. Additionally, you may choose to finetune the entire model or freeze the transformer and finetune only the action head. Finally, you can choose to finetune on your data with goal image conditioning, language conditioning, or both.

See the comments in scripts/configs/finetune_config.py for an explanation of how to configure these finetuning options.
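The decision logic behind the finetuning options described above can be sketched in plain Python. This is an illustration only and does not reflect the actual fields of scripts/configs/finetune_config.py: the function `plan_finetuning`, its arguments, and the observation and action-space names in the usage example are all hypothetical.

```python
def plan_finetuning(
    obs_keys,
    action_space,
    pretrained_obs_keys,
    pretrained_action_spaces,
    freeze_transformer=False,
    use_language=True,
    use_goal_image=False,
):
    """Summarize which finetuning options apply (hypothetical helper, not repo API)."""
    if not (use_language or use_goal_image):
        raise ValueError("choose language conditioning, goal image conditioning, or both")

    # Observation keys unseen in pre-training need freshly initialized tokenizers.
    new_tokenizers = sorted(set(obs_keys) - set(pretrained_obs_keys))
    # An action space unseen in pre-training needs a freshly initialized head.
    needs_new_head = action_space not in pretrained_action_spaces

    conditioning = []
    if use_language:
        conditioning.append("language")
    if use_goal_image:
        conditioning.append("goal_image")

    return {
        "new_observation_tokenizers": new_tokenizers,
        "new_action_head": needs_new_head,
        # True only when every tokenizer and the action head can be reused.
        "reuse_all_pretrained_weights": not new_tokenizers and not needs_new_head,
        "trainable": "action_head_only" if freeze_transformer else "full_model",
        "conditioning": conditioning,
    }


if __name__ == "__main__":
    plan = plan_finetuning(
        obs_keys=["image_primary", "proprio_arm"],        # hypothetical names
        action_space="new_gripper",                        # hypothetical name
        pretrained_obs_keys=["image_primary", "image_wrist"],
        pretrained_action_spaces=["single_arm", "bimanual"],
        freeze_transformer=True,
    )
    print(plan)
```

In every branch the pre-trained transformer backbone is kept; the plan only records which tokenizers and heads would be new and which parts would be trained.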

Citation

@article{doshi24-crossformer,
    title={Scaling Cross-Embodied Learning: One Policy for Manipulation, Navigation, Locomotion and Aviation},
    author={Ria Doshi and Homer Walke and Oier Mees and Sudeep Dasari and Sergey Levine},
    journal={arXiv preprint arXiv:2408.11812},
    year={2024}
}