The official implementation of the paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"

BrushNet

This repository contains the implementation of the paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"

Keywords: Image Inpainting, Diffusion Models, Image Generation

Xuan Ju¹,², Xian Liu¹,², Xintao Wang¹*, Yuxuan Bian², Ying Shan¹, Qiang Xu²*
¹ARC Lab, Tencent PCG  ²The Chinese University of Hong Kong  *Corresponding Author

🌐Project Page | 📜Arxiv | 🗄️Data | 📹Video | 🤗Hugging Face Demo

TODO

  • Release training and inference code
  • Release checkpoint (sdv1.5)
  • Release checkpoint (sdxl)
  • Release evaluation code
  • Release gradio demo

πŸ› οΈ Method Overview

BrushNet is a diffusion-based, text-guided image inpainting model that can be plugged into any pre-trained diffusion model. Our architectural design incorporates two key insights: (1) separating the masked image features from the noisy latent reduces the model's learning load, and (2) applying dense per-pixel control over the entire pre-trained model makes it better suited to image inpainting tasks. More analysis can be found in the main paper.
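
To make the terminology concrete, the sketch below shows on a tiny grayscale grid what a masked image is and how an inpainted result is commonly composited back onto the original. It is a conceptual illustration in plain Python, not code from this repository. The function names `mask_out` and `composite` are hypothetical, and the convention that a mask value of 1 marks the region to inpaint is an assumption.

```python
# Conceptual sketch only; not the repository's implementation.
# Assumed convention: mask == 1 marks pixels to inpaint, 0 marks known pixels.

def mask_out(image, mask):
    """Zero out the pixels to be inpainted, keeping only the known content."""
    return [[px * (1 - m) for px, m in zip(img_row, mask_row)]
            for img_row, mask_row in zip(image, mask)]


def composite(original, generated, mask):
    """Take generated pixels inside the mask and original pixels outside it."""
    return [[g * m + o * (1 - m) for o, g, m in zip(o_row, g_row, m_row)]
            for o_row, g_row, m_row in zip(original, generated, mask)]


image = [[0.2, 0.4], [0.6, 0.8]]      # toy 2x2 grayscale image
mask = [[0, 1], [0, 0]]               # only the top-right pixel is inpainted
generated = [[0.9, 0.5], [0.9, 0.9]]  # stand-in for a model output

masked_image = mask_out(image, mask)
result = composite(image, generated, mask)
print(masked_image)
print(result)
```

In the composited result only the masked pixel comes from the generated image; every other pixel is copied from the original.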

🚀 Getting Started

Environment Requirement 🌍

BrushNet has been implemented and tested on PyTorch 1.12.1 with Python 3.9.

Clone the repo:

git clone https://github.com/TencentARC/BrushNet.git

We recommend first creating a virtual environment with conda, then installing PyTorch following the official instructions. For example:

conda create -n diffusers python=3.9 -y
conda activate diffusers
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

Then, you can install diffusers (implemented in this repo) with:

pip install -e .

After that, you can install the required packages through:

cd examples/brushnet/
pip install -r requirements.txt
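
After installation, a quick sanity check can confirm which versions ended up in the environment. The script below is a minimal sketch using only the standard library; the helper names and the list of packages inspected are assumptions chosen for illustration, not something the repository provides.

```python
import sys
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report(packages=("torch", "torchvision", "diffusers", "accelerate")):
    """Collect the Python version and the versions of the given packages."""
    report = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    for name in packages:
        report[name] = installed_version(name)
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name}: {version or 'not installed'}")
```

With the setup described above, the report should list Python 3.9.x and torch 1.12.1+cu116.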

Data Download ⬇️

Dataset

You can download BrushData and BrushBench here (as well as the EditBench data we re-processed), which are used for training and testing BrushNet. By downloading the data, you agree to the terms and conditions of the license. The data structure should look like this:

|-- data
    |-- BrushData
        |-- 00200.tar
        |-- 00201.tar
        |-- ...
    |-- BrushBench
        |-- images
        |-- mapping_file.json
    |-- EditBench
        |-- images
        |-- mapping_file.json
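
The layout can be checked before training or evaluation with a short script. This is a minimal sketch: the helper names are hypothetical, and the benchmark directory is spelled `BrushBench`, the name the evaluation command in this README expects.

```python
from pathlib import Path

# Entries expected under the data root, relative paths.
EXPECTED = [
    "BrushData",
    "BrushBench/images",
    "BrushBench/mapping_file.json",
    "EditBench/images",
    "EditBench/mapping_file.json",
]


def missing_entries(data_root):
    """Return the expected entries that do not exist under data_root."""
    root = Path(data_root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]


def count_shards(data_root):
    """Count the .tar shards found directly inside BrushData."""
    return len(list((Path(data_root) / "BrushData").glob("*.tar")))


if __name__ == "__main__":
    missing = missing_entries("data")
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print(f"Layout OK, {count_shards('data')} BrushData shard(s) found")
```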

Note: We provide only part of BrushData due to space limits. Please email juxuan.27@gmail.com if you need the full dataset.

Checkpoints

Checkpoints of BrushNet can be downloaded from here. The ckpt folder contains our pretrained checkpoints and a pretrained Stable Diffusion checkpoint (e.g., realisticVisionV60B1_v51VAE from Civitai). You can use scripts/convert_original_stable_diffusion_to_diffusers.py to convert other models downloaded from Civitai. The data structure should look like this:

|-- data
    |-- BrushData
    |-- BrushBench
    |-- EditBench
    |-- ckpt
        |-- realisticVisionV60B1_v51VAE
            |-- model_index.json
            |-- vae
            |-- ...
        |-- segmentation_mask_brushnet_ckpt
        |-- random_mask_brushnet_ckpt
        |-- ...

segmentation_mask_brushnet_ckpt contains checkpoints trained on BrushData with a segmentation prior (the masks have the same shape as the objects). random_mask_brushnet_ckpt provides a more general checkpoint for random mask shapes.
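
For the conversion script mentioned above, the sketch below assembles an invocation for a single-file checkpoint. The flags `--checkpoint_path`, `--dump_path`, and `--from_safetensors` are the ones this script accepts in upstream diffusers as far as we know; please confirm them with `--help` in this repository. The input and output paths are hypothetical placeholders.

```python
import shlex
import subprocess
import sys

CONVERT_SCRIPT = "scripts/convert_original_stable_diffusion_to_diffusers.py"


def build_convert_command(checkpoint_path, dump_path, from_safetensors=True):
    """Assemble the conversion command as an argument list."""
    cmd = [
        sys.executable, CONVERT_SCRIPT,
        "--checkpoint_path", checkpoint_path,  # single-file checkpoint to convert
        "--dump_path", dump_path,              # output folder in diffusers format
    ]
    if from_safetensors:
        # Needed when the input is a .safetensors file rather than a .ckpt file.
        cmd.append("--from_safetensors")
    return cmd


# Hypothetical paths, for illustration only.
cmd = build_convert_command("downloads/my_model.safetensors", "data/ckpt/my_model")
print(shlex.join(cmd))
# To actually run it from the repository root:
# subprocess.run(cmd, check=True)
```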

πŸƒπŸΌ Running Scripts

Training 🀯

You can train with segmentation masks using the script:

accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_segmentationmask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300

To use a custom dataset, convert your data to the BrushData format and change --train_data_dir accordingly.

You can train with random masks by adding --random_mask to the script:

accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_randommask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--random_mask
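
The two training commands differ only in the output directory and the --random_mask flag. As a minimal sketch, a small launcher can build either variant from one function; the function name `build_train_command` is hypothetical, and all flag values are copied from the commands above.

```python
import shlex


def build_train_command(output_dir, random_mask=False, train_data_dir="data/BrushData"):
    """Build the accelerate launch command for BrushNet training."""
    cmd = [
        "accelerate", "launch", "examples/brushnet/train_brushnet.py",
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--output_dir", output_dir,
        "--train_data_dir", train_data_dir,
        "--resolution", "512",
        "--learning_rate", "1e-5",
        "--train_batch_size", "2",
        "--tracker_project_name", "brushnet",
        "--report_to", "tensorboard",
        "--resume_from_checkpoint", "latest",
        "--validation_steps", "300",
    ]
    if random_mask:
        # Switches from segmentation masks to random masks.
        cmd.append("--random_mask")
    return cmd


seg_cmd = build_train_command("runs/logs/brushnet_segmentationmask")
rand_cmd = build_train_command("runs/logs/brushnet_randommask", random_mask=True)
print(shlex.join(seg_cmd))
print(shlex.join(rand_cmd))
```

Either list can be passed to `subprocess.run(cmd, check=True)` from the repository root.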

Inference 📜

You can run inference with the script:

python examples/brushnet/test_brushnet.py

Since BrushNet is trained on LAION, its performance can only be expected to hold in general scenarios. We recommend training on your own data (e.g., product exhibition, virtual try-on) if your application has high-quality industrial requirements. We would also appreciate it if you contributed your trained model!

You can also run inference through the gradio demo:

python examples/brushnet/app_brushnet.py

Evaluation 📏

You can evaluate using the script:

python examples/brushnet/evaluate_brushnet.py \
--brushnet_ckpt_path data/ckpt/segmentation_mask_brushnet_ckpt \
--image_save_path runs/evaluation_result/BrushBench/brushnet_segmask/inside \
--mapping_file data/BrushBench/mapping_file.json \
--base_dir data/BrushBench \
--mask_key inpainting_mask

The --mask_key option selects which kind of mask to use: inpainting_mask for inside inpainting and outpainting_mask for outside inpainting. The evaluation results (images and metrics) will be saved in --image_save_path.
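
To evaluate both mask types in one go, the command can be generated once per mask key. This is a minimal sketch: the helper name is hypothetical, and writing the outside-inpainting results to a sibling `outside` folder is an assumption that mirrors the `inside` path used above.

```python
import shlex

# Region name -> value passed to --mask_key.
MASK_KEYS = {"inside": "inpainting_mask", "outside": "outpainting_mask"}


def build_eval_command(region, save_root="runs/evaluation_result/BrushBench/brushnet_segmask"):
    """Build the evaluation command for inside or outside inpainting."""
    if region not in MASK_KEYS:
        raise ValueError(f"region must be one of {sorted(MASK_KEYS)}")
    return [
        "python", "examples/brushnet/evaluate_brushnet.py",
        "--brushnet_ckpt_path", "data/ckpt/segmentation_mask_brushnet_ckpt",
        "--image_save_path", f"{save_root}/{region}",  # "outside" folder is assumed
        "--mapping_file", "data/BrushBench/mapping_file.json",
        "--base_dir", "data/BrushBench",
        "--mask_key", MASK_KEYS[region],
    ]


for region in MASK_KEYS:
    print(shlex.join(build_eval_command(region)))
```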

Note that you need to bypass the NSFW detector in src/diffusers/pipelines/brushnet/pipeline_brushnet.py#1261 to get correct evaluation results. Moreover, we find that different machines may generate different images, so we provide the results from our machine here.

🤝🏼 Cite Us

@misc{ju2024brushnet,
  title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion}, 
  author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
  year={2024},
  eprint={2403.06976},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

💖 Acknowledgement

Our code is based on diffusers. Thanks to all the contributors!