GreenAug: Green Screen Augmentation Enables Scene Generalisation in Robotic Manipulation

Eugene Teoh, Sumit Patidar, Xiao Ma, Stephen James

Website, Paper

This repo contains the following augmentation methods:

  • greenaug.greenaug_random.GreenAugRandom: This applies random textures to the chroma-keyed background. In our paper, we used mil_data.

  • greenaug.greenaug_generative.GreenAugGenerative: This uses the chroma-keyed mask to inpaint realistic or imagined backgrounds using Stable Diffusion.

  • greenaug.greenaug_mask.GreenAugMask: This uses a masking network to isolate backgrounds as dark pixels during inference. The masking network must be trained first (see the training instructions below).

  • greenaug.generative_augmentation.GenerativeAugmentation: This is an implementation of generative augmentation (e.g. CACTI, GenAug, ROSIE). The implementation is close to ROSIE but with open-source models (Grounding DINO, Segment Anything, Stable Diffusion).

These augmentation methods can be integrated during policy learning (imitation or reinforcement). In our experiments, we used ACT and Coarse-to-fine Q-Network.
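One way to picture this integration is a thin wrapper that augments only the image observations of each training batch and leaves actions and other fields unchanged. The sketch below is framework-agnostic and hypothetical: `augmented_batches`, the `image_key` name, and the dictionary batch layout are assumptions for illustration, not part of this repo or of ACT or Coarse-to-fine Q-Network.

```python
def augmented_batches(batches, augment, image_key="image"):
    """Yield copies of each batch with `augment` applied to the image entry only.

    batches: iterable of dict-like batches (hypothetical layout).
    augment: any callable mapping an image batch to an augmented image batch.
    image_key: hypothetical name of the image observation in each batch.
    """
    for batch in batches:
        out = dict(batch)  # shallow copy so the caller's batch is not modified
        out[image_key] = augment(batch[image_key])
        yield out


# Toy usage with stand-in values; a real pipeline would pass image tensors
# and one of the augmenters listed above as `augment`.
toy_batches = [{"image": 1, "action": 0.5}, {"image": 2, "action": -0.5}]
for batch in augmented_batches(toy_batches, augment=lambda img: img + 10):
    pass  # a policy update on `batch` would go here
```

Because the augmentation is applied as batches are drawn, each pass over the data can see a different background for the same demonstration.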

Installation

Install GreenAug as a Python package:

pip install "greenaug @ git+https://github.com/eugeneteoh/greenaug.git"

To use the generative variants (GreenAugGenerative and GenerativeAugmentation), set the CUDA_HOME environment variable and install cuda-toolkit:

conda create -n greenaug python=3.10 -y
conda activate greenaug
conda env config vars set CUDA_HOME=$CONDA_PREFIX
# Re-activate the environment so the CUDA_HOME setting takes effect
conda activate greenaug

# Install PyTorch
# Follow instructions at https://pytorch.org/get-started/locally/
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
conda install cuda-toolkit -c nvidia/label/cuda-12.1.1 -y

pip install "greenaug[generative] @ git+https://github.com/eugeneteoh/greenaug.git"

To use GreenAugMask:

pip install "greenaug[mask] @ git+https://github.com/eugeneteoh/greenaug.git"

Then see the example below.

Example Usage

Check examples under examples/.

import torch
from greenaug import GreenAugRandom

augmenter = GreenAugRandom()  # This is a torch.nn.Module
out = augmenter(image, ...)

Training the GreenAugMask masking network:

# Download data
huggingface-cli download --repo-type dataset eugeneteoh/greenaug --include "GreenScreenDemoCollection/open_drawer_green_screen.mp4" --local-dir "assets/mask/raw/"
huggingface-cli download --repo-type dataset eugeneteoh/mil_data --include "*.png" --local-dir "assets/mask/background/"

# Preprocess data
python scripts/preprocess_masking_data.py

# Train Masking Network
python scripts/train_masking_network.py

# Run example
python examples/greenaug_mask.py --checkpoint /path/to/checkpoint