xhinker/sd_embed

Stable Cascade support, reference implementation for consideration

Teriks opened this issue · 3 comments

I am experimenting with Stable Cascade weighted prompting.

I am not sure whether you would like this as a pull request, but I figured you might be interested in the code. I have also worked out how to make this work with the compel library.

Here is what I have so far; it is essentially the same as the SDXL implementation with a few tweaks.

An example usage and its output image follow the function implementation below.

This follows the `device` argument and memory-management conventions used in my fork.

This can be dropped into `embedding_funcs.py` by adding `from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline` to its imports.

@torch.inference_mode()
def _s_cascade_encode_weighted_chunk(pipe, tokens, weights, device):
    """Encode one token chunk with ``pipe.text_encoder`` and apply per-token weights.

    Args:
        pipe: Stable Cascade prior or decoder pipeline exposing ``text_encoder``.
        tokens (list[int]): token ids of a single chunk.
        weights (list[float]): one weight per entry of ``tokens``.
        device: device on which the input tensors are created.

    Returns:
        tuple: ``(token_embedding, pooled)``. ``token_embedding`` is the
        encoder's last hidden state (``hidden_states[-1]``) with each weighted
        token row scaled, keeping a leading batch dimension of 1. ``pooled``
        is ``text_embeds.unsqueeze(1)`` from the same forward pass.
    """
    token_tensor = torch.tensor([tokens], dtype=torch.long, device=device)
    weight_tensor = torch.tensor(weights, dtype=torch.float16, device=device)

    encoder_output = pipe.text_encoder(token_tensor, output_hidden_states=True)
    pooled = encoder_output.text_embeds.unsqueeze(1)

    # Drop the batch dim so row j corresponds to token j. `.to(device)` is a
    # no-op when the encoder already produced the tensor on `device`, which
    # avoids the former CPU round trip.
    token_embedding = encoder_output.hidden_states[-1].squeeze(0).to(device)

    # Test the Python weights rather than `weight_tensor[j]`. Evaluating a
    # device tensor in an `if` forces a host sync per token. The outcome is
    # unchanged: a weight that only rounds to 1.0 in float16 multiplies by
    # exactly 1.0.
    for j, weight in enumerate(weights):
        if weight != 1.0:
            token_embedding[j] = token_embedding[j] * weight_tensor[j]

    return token_embedding.unsqueeze(0), pooled


@torch.inference_mode()
def get_weighted_text_embeddings_s_cascade(
        pipe: StableCascadePriorPipeline | StableCascadeDecoderPipeline
        , prompt: str = ""
        , neg_prompt: str = ""
        , pad_last_block: bool = True
        , device: str | torch.device | None = None
):
    """
    Build weighted prompt embeddings of unlimited length for Stable Cascade.

    The prompt and negative prompt are tokenized with their weights and
    padded to the same length with EOS tokens (weight 1.0). They are then
    split into token chunks by ``group_tokens_and_weights``. Each chunk is
    encoded separately; every token row whose weight is not 1.0 is scaled
    by that weight. The chunk embeddings are concatenated along the
    sequence dimension.

    Args:
        pipe (StableCascadePriorPipeline | StableCascadeDecoderPipeline):
            pipeline whose ``tokenizer`` and ``text_encoder`` are used.
        prompt (str): positive prompt, may contain weight syntax.
        neg_prompt (str): negative prompt, may contain weight syntax.
        pad_last_block (bool): forwarded to ``group_tokens_and_weights``.
        device (str | torch.device | None): device for the encoder inputs and
            the returned sequence embeddings; defaults to ``pipe.device``.

    Returns:
        prompt_embeds (torch.Tensor): concatenated weighted embeddings, on ``device``.
        neg_prompt_embeds (torch.Tensor): same for the negative prompt, on ``device``.
        pooled_prompt_embeds (torch.Tensor): pooled ``text_embeds`` of the LAST chunk.
        negative_pooled_prompt_embeds (torch.Tensor): same for the negative prompt.

    NOTE(review): the pooled outputs come from the last chunk only, because
    each chunk overwrites them. For prompts that span several chunks this is
    the tail of the prompt, not its beginning — confirm this is intended.
    """
    eos = pipe.tokenizer.eos_token_id
    device = device if device else pipe.device

    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )
    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, neg_prompt
    )

    # Pad the shorter side with EOS tokens (weight 1.0) so both prompts split
    # into the same number of chunks.
    length_diff = len(prompt_tokens) - len(neg_prompt_tokens)
    if length_diff > 0:
        neg_prompt_tokens = neg_prompt_tokens + [eos] * length_diff
        neg_prompt_weights = neg_prompt_weights + [1.0] * length_diff
    elif length_diff < 0:
        prompt_tokens = prompt_tokens + [eos] * -length_diff
        prompt_weights = prompt_weights + [1.0] * -length_diff

    # Pass copies so the grouping helper cannot mutate our lists.
    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy()
        , prompt_weights.copy()
        , pad_last_block=pad_last_block
    )
    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy()
        , neg_prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    embeds = []
    neg_embeds = []
    pooled_prompt_embeds = None
    negative_pooled_prompt_embeds = None

    # Both sides were padded to equal length, so they have equal group counts.
    for tokens, weights, neg_tokens, neg_weights in zip(
            prompt_token_groups, prompt_weight_groups,
            neg_prompt_token_groups, neg_prompt_weight_groups):
        token_embedding, pooled_prompt_embeds = _s_cascade_encode_weighted_chunk(
            pipe, tokens, weights, device
        )
        # Park finished chunks on the CPU to keep peak VRAM low.
        embeds.append(token_embedding.cpu())

        neg_token_embedding, negative_pooled_prompt_embeds = _s_cascade_encode_weighted_chunk(
            pipe, neg_tokens, neg_weights, device
        )
        neg_embeds.append(neg_token_embedding.cpu())

        # Free VRAM & RAM held by this chunk before encoding the next one.
        del token_embedding, neg_token_embedding
        torch.cuda.empty_cache()

    prompt_embeds = torch.cat(embeds, dim=1).to(device)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1).to(device)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Example usage:

import torch
from sd_embed.embedding_funcs import get_weighted_text_embeddings_s_cascade
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

# Example: weighted prompting through both Stable Cascade pipelines (prior, then decoder).
device = 'cuda'

pos_prompt = "an image of a shiba inu with (blue eyes:1.4), donning a (green) spacesuit, (cartoon style:1.6)"
neg_prompt = "photograph, real"


def weighted_prompt_kwargs(pipe):
    """Encode ``pos_prompt``/``neg_prompt`` with ``pipe``'s text encoder and
    return them as the four embedding keyword arguments of the pipeline call."""
    embeds, neg_embeds, pooled, neg_pooled = get_weighted_text_embeddings_s_cascade(
        pipe, pos_prompt, neg_prompt)
    return {
        'prompt_embeds': embeds,
        'negative_prompt_embeds': neg_embeds,
        'prompt_embeds_pooled': pooled,
        'negative_prompt_embeds_pooled': neg_pooled,
    }


# Prior pipeline: produces the image embeddings consumed by the decoder.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", variant='bf16', torch_dtype=torch.bfloat16
).to(device)

generator = torch.Generator(device=device).manual_seed(0)

prior_kwargs = weighted_prompt_kwargs(prior)
prior_output = prior(
    generator=generator,
    num_inference_steps=20,
    guidance_scale=4,
    **prior_kwargs,
)

# Drop the prior's embeddings and move it off the GPU before loading the decoder.
del prior_kwargs
prior.to('cpu')

# Decoder pipeline: turns the prior's image embeddings into the final image.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", variant='bf16', torch_dtype=torch.float16
).to(device)

decoder_kwargs = weighted_prompt_kwargs(decoder)
image = decoder(
    generator=generator,
    num_inference_steps=10,
    guidance_scale=0.0,
    image_embeddings=prior_output.image_embeddings.half(),
    **decoder_kwargs,
).images[0]

del decoder_kwargs
decoder.to('cpu')

image.save('test.png')

Example output:

test

Oh, yes, please! I will read it, test it out, and merge it. Thank you!

Add stable cascade support #19

Code merged.