SaRA: High-Efficient Diffusion Model Fine-tuning with Progressive Sparse Low-Rank Adaptation

Teng Hu, Jiangning Zhang, Ran Yi, Hongrui Huang, Yabiao Wang, and Lizhuang Ma

🛠️ Installation

  • Env: We have tested on Python 3.9.5 and CUDA 11.8 (other versions may also be fine).
  • Dependencies: pip install -r requirements.txt

🚀 Quick Start

🚀 Run SaRA by modifying a single line of code

You can easily employ SaRA to fine-tune a pre-trained model by modifying a single line of code:

from optim import adamw                     # SaRA's optimizer, used in place of the usual AdamW

model = Initialize_model()                  # load the pre-trained model as usual
optimizer = adamw(model, threshold=2e-3)    # modify this line only
for data in dataloader:
    model.train()                           # the rest of the training loop stays unchanged
model.save()
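The threshold argument controls which pre-trained parameters SaRA treats as temporarily ineffective and therefore trainable: those whose absolute value falls below it. The snippet below is an illustrative sketch of that selection rule in plain PyTorch, not the library's internals; sparse_mask and the random weight tensor are made up for the example:

import torch

def sparse_mask(param: torch.Tensor, threshold: float = 2e-3) -> torch.Tensor:
    # parameters whose absolute value is below the threshold are considered
    # to contribute little to generation and are selected for fine-tuning
    return param.abs() < threshold

w = torch.randn(320, 320) * 1e-2            # stand-in for a pre-trained weight matrix
mask = sparse_mask(w, threshold=2e-3)
print(f"trainable fraction: {mask.float().mean().item():.2%}")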

🚀 Save and load only the trainable parameters

If you want to save only the trainable parameters, use optimizer.save_params(), which saves only the fine-tuned parameters (e.g., 5M or 10M parameters) rather than the whole model:

optimizer = adamw(model, threshold=2e-3)
# save only the fine-tuned parameters, not the whole model
torch.save(optimizer.save_params(), $path_to_save)
# load them back later
optimizer.load($path_to_save)

🍺 Examples

📖 Datasets

For the downstream dataset fine-tuning task, we employ five datasets: BarbieCore, CyberPunk, ElementFire, Expedition, and Hornify (Google Drive). Each dataset is structured as:

dataset_name
   ├── name1.png
   ├── name2.png
   ├── ...
   └── metadata.jsonl

where metadata.jsonl contains the prompts (captioned by BLIP) for each image.
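If you want to fine-tune on your own images, you will need to create this file yourself. Below is a small sketch that writes a metadata.jsonl; the field names file_name and text follow the common Hugging Face imagefolder convention and are an assumption here, and the captions are made up (in practice they would come from BLIP):

import json
from pathlib import Path

dataset_dir = Path("examples/dataset/my_dataset")   # hypothetical dataset folder
captions = {                                         # made-up captions; normally produced by BLIP
    "name1.png": "a pink convertible in barbiecore style",
    "name2.png": "a glittering beach house in barbiecore style",
}

with (dataset_dir / "metadata.jsonl").open("w") as f:
    for file_name, text in captions.items():
        f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")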

🚀 Fine-tuning on a downstream dataset

Put the downloaded datasets in examples/dataset, and then run:

cd examples
python3 finetune.py \
   --config=configs/Barbie.json \
   --output_dir=$path_to_save \
   --sd_version=1.5 \
   --threshold=2e-3 \
   --lr_scheduler=cosine \
   --progressive_iter=2500 \
   --lambda_rank=0.0005

Or you can just run bash finetune.sh.
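After training, the parameters saved by SaRA can be loaded back for generation. The snippet below is a minimal sketch under stated assumptions: it assumes a diffusers Stable Diffusion 1.5 pipeline, and that wrapping the UNet with adamw and calling optimizer.load restores the saved parameters; the model ID, path, and prompt are illustrative:

import torch
from diffusers import StableDiffusionPipeline
from optim import adamw

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
# assumed workflow: wrap the UNet and restore the SaRA-trained parameters into it
optimizer = adamw(pipe.unet, threshold=2e-3)
optimizer.load("sara_params.pt")            # illustrative path to the saved parameters
image = pipe("a car in barbiecore style").images[0]
image.save("sample.png")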

🚀 Fine-tuning with DreamBooth

Coming Soon

🚀 Fine-tuning AnimateDiff

Coming Soon