Official PyTorch code for the paper *Enhancing Diffusion Models with Text-Encoder Reinforcement Learning*.
- Clone the repo and install the required packages:
```bash
# git clone this repository
git clone https://github.com/chaofengc/TexForce.git
cd TexForce

# create a new anaconda env
conda create -n texforce python=3.8
conda activate texforce

# install python dependencies
pip3 install -r requirements.txt
```
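To confirm that the environment is ready, a small check like the following can help. It is a minimal sketch and not part of the TexForce repository. It assumes only that the packages imported by the usage example in this README (`torch`, `diffusers`, `peft`) should be importable; the helper names `check_packages` and `cuda_available` are hypothetical.

```python
import importlib.util

# Packages imported by the usage example in this README.
REQUIRED_PACKAGES = ("torch", "diffusers", "peft")


def check_packages(names=REQUIRED_PACKAGES):
    """Return {name: bool} telling whether each package can be found.

    importlib.util.find_spec locates a package without importing it.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


def cuda_available():
    """Return True if torch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch  # imported lazily so the check works without torch

    return torch.cuda.is_available()


if __name__ == "__main__":
    for name, found in check_packages().items():
        print(f"{name:10s} {'ok' if found else 'MISSING'}")
    print(f"{'cuda':10s} {'ok' if cuda_available() else 'not available'}")
```

If any package is reported as `MISSING`, re-run `pip3 install -r requirements.txt` inside the `texforce` environment.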
You can load the pretrained LoRA weights with the following code to improve the performance of the original Stable Diffusion model:
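```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from peft import PeftModel

def load_model_weights(pipe, weight_path, model_type):
    """Load LoRA weights into the text encoder or the UNet of `pipe`."""
    if model_type == 'text+lora':
        # PeftModel.from_pretrained injects the LoRA layers into
        # pipe.text_encoder in place, so the pipeline uses them directly.
        PeftModel.from_pretrained(pipe.text_encoder, weight_path)
    elif model_type == 'unet+lora':
        # Load the LoRA attention processors into the UNet.
        pipe.unet.load_attn_procs(weight_path)
    else:
        raise ValueError(f"Unknown model_type: {model_type!r}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# float16 is intended for GPU inference; fall back to float32 on CPU.
dtype = torch.float16 if device.type == 'cuda' else torch.float32

model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# ReFL LoRA for the UNet, TexForce LoRA for the text encoder
load_model_weights(pipe, './lora_weights/sd14_refl/', 'unet+lora')
load_model_weights(pipe, './lora_weights/sd14_texforce/', 'text+lora')

prompt = ['a painting of a dog.']
img = pipe(prompt).images[0]  # a PIL.Image
img.save('dog.png')
```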
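The results table compares four settings. A plausible reading, which is an assumption and not stated explicitly in this README, is that they correspond to loading no LoRA weights (SDv1.4), only the UNet LoRA from `sd14_refl` (ReFL), only the text-encoder LoRA from `sd14_texforce` (TexForce), or both (ReFL+TexForce). The hypothetical helpers `weights_for` and `apply_config` below encode that mapping and apply it through any loader with the same signature as `load_model_weights`.

```python
# Hypothetical mapping from the results-table columns to the
# (weight_path, model_type) pairs passed to load_model_weights().
# Paths and model types are copied from the usage example; the mapping
# itself is an assumption, not taken from the TexForce codebase.
REFL_UNET = ("./lora_weights/sd14_refl/", "unet+lora")
TEXFORCE_TEXT = ("./lora_weights/sd14_texforce/", "text+lora")

CONFIGS = {
    "SDv1.4": [],  # original Stable Diffusion v1.4, no LoRA weights
    "ReFL": [REFL_UNET],
    "TexForce": [TEXFORCE_TEXT],
    "ReFL+TexForce": [REFL_UNET, TEXFORCE_TEXT],
}


def weights_for(config_name):
    """Return the (weight_path, model_type) pairs for one configuration."""
    if config_name not in CONFIGS:
        raise KeyError(f"unknown configuration {config_name!r}; choose from {list(CONFIGS)}")
    return list(CONFIGS[config_name])


def apply_config(pipe, config_name, load_fn):
    """Load every LoRA weight of `config_name` into `pipe` via `load_fn`.

    `load_fn(pipe, weight_path, model_type)` is expected to behave like
    load_model_weights. The loaders modify the pipeline's modules in place,
    so call this on a freshly created pipeline for each configuration.
    """
    for weight_path, model_type in weights_for(config_name):
        load_fn(pipe, weight_path, model_type)
    return pipe
```

For a side-by-side comparison, create a new pipeline per configuration, call `apply_config(pipe, name, load_model_weights)`, and pass the same seeded generator to every run, e.g. `pipe(prompt, generator=torch.Generator(device).manual_seed(0))`, so that each column starts from the same initial noise.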
Here are some example results:
[Figure: example results comparing SDv1.4, ReFL, TexForce, and ReFL+TexForce]
If you find this code useful for your research, please cite our paper:
```bibtex
@article{chen2023texforce,
  title={Enhancing Diffusion Models with Text-Encoder Reinforcement Learning},
  author={Chaofeng Chen and Annan Wang and Haoning Wu and Liang Liao and Wenxiu Sun and Qiong Yan and Weisi Lin},
  year={2023},
  eprint={2311.15657},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.