This repository provides the official PyTorch implementation of PromptFix, including pre-trained weights, training and inference code, and our curated dataset used for training.
📢 PromptFix is designed to follow human instructions to process degraded images and remove unwanted elements. It supports a wide range of tasks, such as:
- 🎨 Colorization
- 🧹 Object Removal
- 🌫️ Dehazing
- 💨 Deblurring
- 🖼️ Watermark Removal
- ❄️ Snow Removal
- 🌙 Low-light Enhancement
Built on a diffusion model backbone, PromptFix effectively corrects image defects while preserving the structure of the original image, using a 20-step denoising process. It also generalizes well to images with different aspect ratios.
- Environment Setup
- Inference
- Download Dataset
- 🧑‍💻 Training
- 📝 Citing PromptFix
- 🙏 Acknowledgments
- ⚠️ Disclaimer
Follow the steps below to clone the repository, set up the environment, and install the dependencies. The code has been tested with Python 3.10.
```bash
git clone https://github.com/yeates/PromptFix.git
cd PromptFix
conda create -n promptfix python=3.10 -y
conda activate promptfix
pip install -r requirements.txt
```
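After activating the environment, a quick sanity check can catch a mismatched interpreter or a CPU-only PyTorch build. The following is a minimal sketch, not part of the repository; it assumes PyTorch is installed through `requirements.txt`, and the helper names `python_matches` and `torch_cuda_status` are hypothetical.

```python
import importlib.util
import sys

# The README states the code was tested with Python 3.10.
TESTED_PYTHON = (3, 10)


def python_matches(version_info=sys.version_info, expected=TESTED_PYTHON):
    """Return True if the interpreter's major.minor version equals `expected`."""
    return tuple(version_info[:2]) == expected


def torch_cuda_status():
    """Return (torch_installed, cuda_available) without failing if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return False, False
    import torch  # imported lazily so the check also runs in a bare environment

    return True, bool(torch.cuda.is_available())
```

Inside the `promptfix` environment on a machine with a working CUDA setup, one would expect `python_matches()` to return `True` and `torch_cuda_status()` to return `(True, True)`.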
To process the default example images, run the following command. The pre-trained model weights will be downloaded automatically from Hugging Face and placed in the `checkpoints/` directory:
```bash
bash scripts/inference.sh
```
We curated a training dataset of over 1 million samples. Each sample contains a pair of images together with an instruction prompt and an auxiliary text prompt. The dataset covers multiple low-level image processing tasks.
To download the dataset, run the following command from the project root directory:
```bash
bash scripts/download_promptfix_dataset.sh
```
The dataset includes the following tasks:
| Task | Percentage |
|---|---|
| 🎨 Colorization | 29.3% |
| 🌙 Low-light Enhancement | 20.7% |
| 🖼️ Watermark Removal | 12.4% |
| 🧹 Object Removal | 11.9% |
| ❄️ Snow Removal | 9.7% |
| 🌫️ Dehazing | 8.9% |
| 💨 Deblurring | 7.1% |
| Total | 100% |
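To get a rough sense of scale, the shares in the task table can be turned into approximate per-task sample counts. This sketch assumes a total of exactly 1,000,000 samples; because the dataset is only described as exceeding 1 million samples, the resulting numbers are approximate lower bounds, not official counts. The names `TASK_SHARES` and `approx_task_counts` are hypothetical.

```python
# Task shares copied from the task table (percent of all training samples).
TASK_SHARES = {
    "Colorization": 29.3,
    "Low-light Enhancement": 20.7,
    "Watermark Removal": 12.4,
    "Object Removal": 11.9,
    "Snow Removal": 9.7,
    "Dehazing": 8.9,
    "Deblurring": 7.1,
}


def approx_task_counts(total_samples, shares=TASK_SHARES):
    """Convert percentage shares into approximate per-task sample counts."""
    return {task: round(total_samples * pct / 100) for task, pct in shares.items()}
```

For example, `approx_task_counts(1_000_000)["Colorization"]` gives 293000, so roughly three in ten samples are colorization pairs.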
Note: The dataset is split into 100 Parquet files, each of which can be loaded independently. To experiment with a smaller amount of data without downloading the entire dataset, you can download only a few of the Parquet files.
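Since each Parquet part loads independently, a subset can be combined into a single table. The sketch below is an illustration, not the repository's data loader: it assumes the downloaded parts sit in one directory with a `.parquet` extension and that pandas plus a Parquet engine such as pyarrow is installed. The helper names and any directory you pass are hypothetical, and the column names are not documented here, so inspect `df.columns` after loading.

```python
from pathlib import Path


def select_parts(data_dir, num_parts):
    """Return the first `num_parts` Parquet files in `data_dir`, sorted by name."""
    parts = sorted(Path(data_dir).glob("*.parquet"))
    if not parts:
        raise FileNotFoundError(f"no .parquet files found in {data_dir}")
    return parts[:num_parts]


def load_parts(data_dir, num_parts=2):
    """Load a few Parquet parts into one pandas DataFrame.

    pandas is imported lazily so that `select_parts` works without it.
    """
    import pandas as pd

    frames = [pd.read_parquet(path) for path in select_parts(data_dir, num_parts)]
    return pd.concat(frames, ignore_index=True)
```

Usage, with a hypothetical path: `df = load_parts("data/promptfix", num_parts=2)`.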
To train the model, run:
```bash
bash scripts/train.sh <GPU_NUMS>
```
Replace `<GPU_NUMS>` with the number of GPUs you wish to use.
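If you would rather not count GPUs by hand, a small launcher can derive the number and build the training command. This is a hypothetical helper, not part of the repository: it assumes you run it from the project root, counts devices from `CUDA_VISIBLE_DEVICES` when that variable is set (otherwise it asks PyTorch), and does not assume that `scripts/train.sh` itself reads `CUDA_VISIBLE_DEVICES`.

```python
import os


def visible_gpu_count(env=None):
    """Count GPUs from CUDA_VISIBLE_DEVICES, else ask PyTorch (0 if torch is absent)."""
    env = os.environ if env is None else env
    visible = env.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        # e.g. "0,1,3" -> 3 devices; an empty string hides all GPUs.
        return len([dev for dev in visible.split(",") if dev.strip()])
    try:
        import torch
    except ImportError:
        return 0
    return torch.cuda.device_count()


def train_command(num_gpus):
    """Build the argument list for scripts/train.sh with the given GPU count."""
    if num_gpus < 1:
        raise ValueError("training needs at least one GPU")
    return ["bash", "scripts/train.sh", str(num_gpus)]
```

The command can then be launched with `subprocess.run(train_command(visible_gpu_count()), check=True)`.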
Once checkpoints have been saved, convert the EMA (Exponential Moving Average) weights into a loadable checkpoint:
```bash
python scripts/convert_ckpt.py --ema-ckpt <EMA_CKPT_PATH> --out-ckpt <OUT_CKPT_PATH>
```
For example:
```bash
python scripts/convert_ckpt.py --ema-ckpt ./train_logs/promptfix/checkpoints/ckpt_epoch_0/state.pth --out-ckpt ./checkpoints/promptfix_epoch_1.ckpt
```
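For readers unfamiliar with EMA weights: EMA training typically keeps a smoothed "shadow" copy of each parameter, updated after every optimizer step as `ema = decay * ema + (1 - decay) * current`. The sketch below only illustrates that update rule, with plain Python floats standing in for tensors; it does not describe the checkpoint layout that `convert_ckpt.py` reads, and the decay of 0.999 is an illustrative value, not PromptFix's setting. `ema_update` is a hypothetical name.

```python
def ema_update(ema_weights, model_weights, decay=0.999):
    """One EMA step per parameter: ema <- decay * ema + (1 - decay) * current."""
    for name, value in model_weights.items():
        ema_weights[name] = decay * ema_weights[name] + (1.0 - decay) * value
    return ema_weights
```

With real tensors, the same update is usually applied in place under `torch.no_grad()`, for example `ema.mul_(decay).add_(param, alpha=1 - decay)`. A high decay makes the shadow weights change slowly, which tends to give smoother weights for inference than the raw training weights.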
If you use our dataset or code, please give the repository a star ⭐ and cite our paper:
```bibtex
@inproceedings{yu2024promptfix,
  title={PromptFix: You Prompt and We Fix the Photo},
  author={Yu, Yongsheng and Zeng, Ziyun and Hua, Hang and Fu, Jianlong and Luo, Jiebo},
  booktitle={NeurIPS},
  year={2024}
}
```
We would like to thank the authors of InstructDiffusion, Stable Diffusion, and InstructPix2Pix for sharing their code.
This repository is part of an open-source research initiative provided for academic and research purposes only. We have not established any official commercial services, products, or web applications related to this project. Use this software at your own risk; it may not meet all your expectations or requirements.
Please note that the PromptFix dataset is curated from open-source research projects and publicly available photo libraries. By using our dataset, you automatically agree to comply with all applicable licenses and terms of use associated with the source data. Furthermore, you acknowledge and agree that neither the dataset nor any models trained using it may be utilized for any commercial purposes.