Eugene Teoh, Sumit Patidar, Xiao Ma, Stephen James
This repo contains the following augmentation methods:
- greenaug.greenaug_random.GreenAugRandom: This applies random textures to the chroma-keyed background. In our paper, we used mil_data.
- greenaug.greenaug_generative.GreenAugGenerative: This uses the chroma-keyed mask to inpaint realistic or imagined backgrounds using Stable Diffusion.
- greenaug.greenaug_mask.GreenAugMask: This uses a masking network to isolate backgrounds as dark pixels during inference. One first needs to train a masking network (see instructions below).
- greenaug.generative_augmentation.GenerativeAugmentation: This is an implementation of generative augmentation (e.g. CACTI, GenAug, ROSIE). The implementation is close to ROSIE but with open-source models (Grounding DINO, Segment Anything, Stable Diffusion).
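The chroma-key idea behind GreenAugRandom can be illustrated with a minimal NumPy sketch. This is not the package's implementation: the function names, the `green_margin` threshold, and the simple "green dominates red and blue" rule are all assumptions chosen for illustration, and a random array stands in for a texture image.

```python
import numpy as np


def chroma_key_mask(image, green_margin=40):
    """Return a boolean (H, W) mask that is True where a pixel looks like green screen.

    A pixel counts as background when its green channel exceeds both the red
    and blue channels by at least `green_margin` (a hypothetical threshold).
    """
    rgb = image.astype(np.int16)  # avoid uint8 wrap-around when subtracting
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    return (g - r >= green_margin) & (g - b >= green_margin)


def replace_background(image, texture, mask):
    """Composite `texture` into `image` wherever `mask` is True."""
    return np.where(mask[..., None], texture, image)


rng = np.random.default_rng(0)

# Toy 4x4 frame: pure green background with a 2x2 red "object" in the centre.
frame = np.zeros((4, 4, 3), dtype=np.uint8)
frame[...] = (0, 255, 0)
frame[1:3, 1:3] = (200, 30, 30)

# Stand-in for a texture image sampled from a texture dataset.
texture = rng.integers(0, 256, size=frame.shape, dtype=np.uint8)

mask = chroma_key_mask(frame)
augmented = replace_background(frame, texture, mask)
```

In this sketch the object pixels are left untouched and only the keyed background is swapped, so sampling a new texture at every training step gives a different background for the same demonstration frame.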
These augmentation methods can be applied during policy learning (imitation learning or reinforcement learning). In our experiments, we used ACT and Coarse-to-fine Q-Network.
Install GreenAug as a Python package:
pip install "greenaug @ git+https://github.com/eugeneteoh/greenaug.git"
To use the generative variants (GreenAugGenerative and GenerativeAugmentation), set the CUDA_HOME environment variable and install cuda-toolkit:
conda create -n greenaug python=3.10 -y
conda activate greenaug
conda env config vars set CUDA_HOME=$CONDA_PREFIX
conda activate greenaug
# Install PyTorch
# Follow instructions at https://pytorch.org/get-started/locally/
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
conda install cuda-toolkit -c nvidia/label/cuda-12.1.1 -y
pip install "greenaug[generative] @ git+https://github.com/eugeneteoh/greenaug.git"
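Before installing the generative extras, it can help to confirm that CUDA_HOME is visible to Python and points at a toolkit. The helper below is a hypothetical sketch, not part of greenaug; it assumes the usual toolkit layout in which the compiler lives at `$CUDA_HOME/bin/nvcc`.

```python
import os
from pathlib import Path


def check_cuda_home(env=None):
    """Return (ok, message) saying whether CUDA_HOME points at a toolkit with nvcc."""
    env = os.environ if env is None else env
    cuda_home = env.get("CUDA_HOME")
    if not cuda_home:
        return False, "CUDA_HOME is not set"
    # Assumed layout: the compiler sits in <CUDA_HOME>/bin (nvcc.exe on Windows).
    candidates = [Path(cuda_home) / "bin" / name for name in ("nvcc", "nvcc.exe")]
    if not any(path.is_file() for path in candidates):
        return False, f"no nvcc found under {cuda_home}/bin"
    return True, f"found nvcc under {cuda_home}/bin"


if __name__ == "__main__":
    ok, message = check_cuda_home()
    print(message)
```

If the check reports that CUDA_HOME is not set, re-activating the conda environment after `conda env config vars set` is the first thing to try, since the variable only takes effect on activation.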
To use GreenAugMask:
pip install "greenaug[mask] @ git+https://github.com/eugeneteoh/greenaug.git"
Then see the example below.
Check the examples under examples/. Basic usage:
import torch
from greenaug import GreenAugRandom
augmenter = GreenAugRandom() # This is a torch.nn.Module
out = augmenter(image, ...)
Training the GreenAugMask masking network:
# Download data
huggingface-cli download --repo-type dataset eugeneteoh/greenaug --include "GreenScreenDemoCollection/open_drawer_green_screen.mp4" --local-dir "assets/mask/raw/"
huggingface-cli download --repo-type dataset eugeneteoh/mil_data --include "*.png" --local-dir "assets/mask/background/"
# Preprocess data
python scripts/preprocess_masking_data.py
# Train Masking Network
python scripts/train_masking_network.py
# Run example
python examples/greenaug_mask.py --checkpoint /path/to/checkpoint
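To make concrete what "isolating backgrounds as dark pixels" means for GreenAugMask, here is a minimal NumPy sketch of the final masking step. It is an illustration under stated assumptions, not the repo's code: it assumes the masking network outputs a per-pixel foreground probability, and the function name and the 0.5 threshold are hypothetical.

```python
import numpy as np


def darken_background(image, foreground_prob, threshold=0.5):
    """Zero out pixels whose predicted foreground probability is below `threshold`."""
    keep = foreground_prob >= threshold  # (H, W) boolean foreground mask
    return image * keep[..., None].astype(image.dtype)


# Toy 2x3 grey image.
image = np.full((2, 3, 3), 120, dtype=np.uint8)

# Hypothetical network output: probability that each pixel is foreground.
foreground_prob = np.array([[0.9, 0.2, 0.7],
                            [0.1, 0.6, 0.4]])

masked = darken_background(image, foreground_prob)
```

Pixels predicted as background become black while foreground pixels keep their original values, so the policy sees the same kind of input regardless of what the real background looks like.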