JoaoLages/diffusers-interpret

StableDiffusionPipelineExplainer enable_attention_slicing() and limit token attribution

TomPham97 opened this issue · 4 comments

Version 0.3.0 of 🤗 Diffusers introduces enable_attention_slicing(), and I wonder whether it can be used with the explainer. Below is the code I used; it ran out of CUDA memory:

# Import pipeline
import torch
from diffusers import StableDiffusionPipeline

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token = True,
    revision = "fp16" if torch_device != "cpu" else None,
    torch_dtype = torch.float16 if torch_device != "cpu" else None)

pipe.to(torch_device)

pipe.enable_attention_slicing() # attention optimization for less memory usage

# Pass pipeline to the explainer class

from diffusers_interpret import StableDiffusionPipelineExplainer

explainer = StableDiffusionPipelineExplainer(pipe)

prompt = "photograph, piggy, corn salad"

with torch.autocast(torch_device):
    output = explainer(prompt,
                       guidance_scale=7.5,
                       num_inference_steps=17)

output.image

Attention slicing is fully compatible with the explainer :)
However, running out of GPU memory is expected, because calculating the attributions is expensive: the gradients have to be backpropagated through the diffusion model.
I recommend adding the arguments I used in this notebook.
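Attention slicing changes how attention is computed, not what it computes, which is presumably why it composes with the explainer. Below is a minimal sketch of the idea in plain PyTorch, assuming q, k and v are already reshaped to (batch * heads, tokens, head_dim). `sliced_attention` is a hypothetical helper for illustration, not diffusers' implementation.

```python
import torch


def sliced_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, slice_size: int) -> torch.Tensor:
    """Scaled dot-product attention over `slice_size` entries of the first dim at a time.

    q, k, v have shape (batch * heads, tokens, head_dim). Without autograd, only one
    (slice_size, tokens, tokens) score matrix is alive at a time instead of the full one.
    """
    if slice_size < 1:
        raise ValueError("slice_size must be >= 1")
    scale = q.shape[-1] ** -0.5
    chunks = []
    for start in range(0, q.shape[0], slice_size):
        end = start + slice_size
        scores = torch.bmm(q[start:end], k[start:end].transpose(1, 2)) * scale
        chunks.append(torch.bmm(scores.softmax(dim=-1), v[start:end]))
    # Concatenating the per-slice outputs reproduces the unsliced result.
    return torch.cat(chunks, dim=0)
```

When gradients are tracked, autograd in this sketch keeps every slice's attention probabilities for the backward pass, so the savings are much smaller than during plain inference. That is consistent with attributions still running out of memory even with slicing enabled.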

Enabling gradient_checkpointing helps.

explainer = StableDiffusionPipelineExplainer(
    pipe,
    
    # We pass `True` in here to be able to have a higher `n_last_diffusion_steps_to_consider_for_attributions` in the cell below
    gradient_checkpointing=True 
)
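Gradient checkpointing trades compute for memory. Activations inside a checkpointed segment are not stored for the backward pass; instead, they are recomputed when gradients are needed. Here is a minimal sketch using `torch.utils.checkpoint.checkpoint` on a made-up stack of MLP blocks. `ToyStack` and `input_gradient` are hypothetical, and this sketch does not show which modules the explainer's `gradient_checkpointing=True` actually wraps.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyStack(nn.Module):
    """Hypothetical stand-in for a deep network such as a UNet."""

    def __init__(self, dim: int = 32, depth: int = 4, use_checkpointing: bool = False):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)) for _ in range(depth)]
        )
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing:
                # Activations inside `block` are discarded and recomputed during backward.
                x = x + checkpoint(block, x, use_reentrant=False)
            else:
                x = x + block(x)
        return x


def input_gradient(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Gradient of a scalar output w.r.t. the input, as an attribution method needs."""
    x = x.detach().requires_grad_(True)
    model(x).sum().backward()
    return x.grad
```

Computing `input_gradient` with the flag off and then on should give the same gradients. The checkpointed run uses less activation memory but spends an extra forward pass per block.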

Reducing n_last_diffusion_steps_to_consider_for_attributions, height, and width also decreases GPU memory usage:

from contextlib import nullcontext

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # must match the device the pipeline is on

prompt = "A cute corgi with the Eiffel Tower in the background"

generator = torch.Generator(device).manual_seed(2023)
with torch.autocast("cuda") if device == "cuda" else nullcontext():
    output = explainer(
        prompt,
        num_inference_steps=50,
        generator=generator,
        height=448,
        width=448,

        # for this model, GPU VRAM usage rises drastically if we increase this argument; feel free to experiment with it
        # if you are not interested in checking the token attributions, you can pass 0 here
        n_last_diffusion_steps_to_consider_for_attributions=5
    )
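Height and width matter so much for a specific reason. Stable Diffusion v1's VAE downsamples each image dimension by 8, and the UNet's highest-resolution self-attention layers attend over every latent position. Their score matrices therefore grow with the square of the number of latent tokens. The back-of-the-envelope helpers below are hypothetical and estimate only this one term, not total VRAM.

```python
from typing import Tuple

VAE_SCALE_FACTOR = 8  # Stable Diffusion v1 latents are 8x smaller than the image in each dimension


def latent_tokens(height: int, width: int, scale: int = VAE_SCALE_FACTOR) -> int:
    """Number of latent positions the full-resolution self-attention attends over."""
    if height % scale or width % scale:
        raise ValueError(f"height and width must be divisible by {scale}")
    return (height // scale) * (width // scale)


def attention_score_ratio(size_a: Tuple[int, int], size_b: Tuple[int, int]) -> float:
    """Ratio of self-attention score-matrix sizes (tokens ** 2) between two resolutions."""
    return latent_tokens(*size_a) ** 2 / latent_tokens(*size_b) ** 2


print(latent_tokens(512, 512))  # 4096
print(latent_tokens(448, 448))  # 3136
print(round(attention_score_ratio((448, 448), (512, 512)), 3))  # 0.586
```

So in those layers, 448×448 needs roughly 59% of the attention-score memory of 512×512.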

Thank you for your reply ☺️

Once I passed the argument gradient_checkpointing=True, all 51 inference steps went through without running out of memory. Not surprisingly, CUDA ran out of memory soon after the "Calculating token attributions..." message appeared.

I then passed the argument n_last_diffusion_steps_to_consider_for_attributions=0, but this error was raised at line 194 of explainer.py:

NotImplementedError: Only `attribution_method='grad_x_input'` is implemented for now

Setting the argument to 1 or None, on the other hand, still ran out of CUDA memory after all the inference steps.
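The error message refers to `attribution_method='grad_x_input'`, i.e. gradient × input. Here is a minimal sketch of that idea on a toy model, assuming token embeddings of shape (tokens, dim) and a model that returns a scalar. The helper name, the L1 aggregation over the embedding dimension, and the normalization to percentages are illustrative assumptions, not diffusers-interpret's code. The sketch also shows where the memory goes: a full backward pass through the model is required.

```python
from typing import Callable

import torch
from torch import nn


def grad_x_input_token_attributions(
    model: Callable[[torch.Tensor], torch.Tensor], embeddings: torch.Tensor
) -> torch.Tensor:
    """Per-token gradient x input scores for a model mapping (tokens, dim) -> scalar.

    Returns one non-negative score per token, normalized to sum to 100 (an illustrative choice).
    """
    embeddings = embeddings.detach().requires_grad_(True)
    output = model(embeddings)
    # The backward pass reuses activations saved during the forward pass; that is the memory cost.
    (grad,) = torch.autograd.grad(output, embeddings)
    # Element-wise gradient x input, then L1 over the embedding dimension: one score per token.
    scores = (grad * embeddings).abs().sum(dim=-1)
    return (100 * scores / scores.sum()).detach()
```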


Yes, that NotImplementedError is my bad... I released a new patch version to fix it.


To clarify, n_last_diffusion_steps_to_consider_for_attributions=None will calculate attributions for the entire diffusion process, while n_last_diffusion_steps_to_consider_for_attributions=0 will not compute attributions at all. I need to write documentation on this!
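Those semantics can be summarized with a small hypothetical helper. It returns the diffusion-step indices that would have gradients tracked for a given n_last_diffusion_steps_to_consider_for_attributions value. It only mirrors the behavior stated in this thread (None = all steps, 0 = none, k = the last k) and is not the library's code.

```python
from typing import List, Optional


def steps_tracked_for_attributions(num_steps: int, n_last: Optional[int]) -> List[int]:
    """Hypothetical: indices of the diffusion steps whose gradients feed the token attributions."""
    if n_last is None:
        return list(range(num_steps))  # None: attributions over the entire diffusion process
    if n_last < 0:
        raise ValueError("n_last must be None or a non-negative integer")
    return list(range(max(num_steps - n_last, 0), num_steps))  # 0 -> [] (no attributions)


print(steps_tracked_for_attributions(50, None)[:3])  # [0, 1, 2]
print(steps_tracked_for_attributions(50, 0))  # []
print(steps_tracked_for_attributions(50, 5))  # [45, 46, 47, 48, 49]
```

Since the memory needed for the backward pass roughly grows with the number of tracked steps, 1 is the cheapest setting that still produces attributions.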

If you are still running out of memory with n_last_diffusion_steps_to_consider_for_attributions=1 (while using gradient checkpointing and attention slicing), your GPU is simply not big enough to compute the gradients for a (512, 512) image 😞

I suggest running on the CPU instead or reducing the image size; on Colab it works with (448, 448).
This Stable Diffusion model is pretty big; I'm still looking for more ways to reduce its GPU VRAM requirements.
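Reducing the image size can be automated by retrying at progressively smaller sizes whenever CUDA runs out of memory. This is a minimal sketch. `generate_with_size_fallback` and its `generate` callback are hypothetical; the callback could be something like `lambda h, w: explainer(prompt, height=h, width=w, ...)`. CUDA out-of-memory errors are `RuntimeError`s (recent PyTorch raises the subclass `torch.cuda.OutOfMemoryError`), so the sketch matches on the message text.

```python
from typing import Callable, Sequence, Tuple, TypeVar

import torch

T = TypeVar("T")


def generate_with_size_fallback(
    generate: Callable[[int, int], T],
    sizes: Sequence[Tuple[int, int]] = ((512, 512), (448, 448), (384, 384)),
) -> Tuple[T, Tuple[int, int]]:
    """Call generate(height, width) with each size in turn until one fits in memory."""
    for height, width in sizes:
        try:
            return generate(height, width), (height, width)
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise  # unrelated errors are not swallowed
        # Outside the except block the exception (and frames holding GPU tensors) has been released.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks before the smaller attempt
    raise RuntimeError(f"out of memory at every size in {list(sizes)}")
```

Emptying the cache after leaving the `except` block matters. While the exception is alive, its traceback can keep the failed attempt's tensors referenced.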


Thank you very much; the feature now works as intended!