Stable Cascade support, reference implementation for consideration
Teriks opened this issue · 3 comments
I am experimenting with Stable Cascade weighted prompting.
I am not sure whether you would like this as a pull request, but I figure you might be interested in the code. I have also determined how to make this work with the compel library; a rough, untested sketch of that route is included at the end of this post.
Here is what I have so far; it is essentially the same as the SDXL implementation with a few tweaks.
An example usage and an output image are given below the function implementation.
This follows the device argument and memory management convention I have in my fork.
This can drop into embedding_funcs.py; the only addition needed to the module's existing imports is the line shown just below, ahead of the function itself.
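# addition to the existing imports in embedding_funcs.py (used by the type hint in the signature below)
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
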
@torch.inference_mode()
def get_weighted_text_embeddings_s_cascade(
        pipe: StableCascadePriorPipeline | StableCascadeDecoderPipeline
        , prompt: str = ""
        , neg_prompt: str = ""
        , pad_last_block: bool = True
        , device: str = None
):
    """
    Generate weighted text embeddings for Stable Cascade. Long prompts are
    processed in chunks, so there is no length limitation.

    Args:
        pipe (StableCascadePriorPipeline | StableCascadeDecoderPipeline)
        prompt (str)
        neg_prompt (str)
        pad_last_block (bool)
        device (str)

    Returns:
        prompt_embeds (torch.Tensor)
        neg_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)
    """
    import math

    eos = pipe.tokenizer.eos_token_id

    device = device if device else pipe.device

    # tokenize both prompts and extract the per-token weights
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )

    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, neg_prompt
    )

    # pad the shorter token sequence so both have the same length
    prompt_token_len = len(prompt_tokens)
    neg_prompt_token_len = len(neg_prompt_tokens)
    if prompt_token_len > neg_prompt_token_len:
        # pad the negative prompt with eos tokens
        neg_prompt_tokens = (
                neg_prompt_tokens +
                [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        neg_prompt_weights = (
                neg_prompt_weights +
                [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )
    else:
        # pad the positive prompt with eos tokens
        prompt_tokens = (
                prompt_tokens
                + [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        prompt_weights = (
                prompt_weights
                + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )

    embeds = []
    neg_embeds = []

    # split the token / weight lists into chunks that fit the text encoder
    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy()
        , prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy()
        , neg_prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    # encode each token group and apply the per-token weights
    for i in range(len(prompt_token_groups)):
        # get positive prompt embeddings with weights
        token_tensor = torch.tensor(
            [prompt_token_groups[i]]
            , dtype=torch.long, device=device
        )
        weight_tensor = torch.tensor(
            prompt_weight_groups[i]
            , dtype=torch.float16
            , device=device
        )

        prompt_embeds_1 = pipe.text_encoder(
            token_tensor.to(device)
            , output_hidden_states=True
        )
        prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-1].cpu()
        pooled_prompt_embeds = prompt_embeds_1.text_embeds.unsqueeze(1)

        # only a single text encoder here, unlike SDXL
        prompt_embeds_list = [prompt_embeds_1_hidden_states]
        token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0).to(device)

        for j in range(len(weight_tensor)):
            if weight_tensor[j] != 1.0:
                # ow = weight_tensor[j] - 1

                # optional process
                # to map a number in (0,1) to (-1,1)
                # tanh_weight = (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
                # weight = 1 + tanh_weight

                # add weight method 1:
                # token_embedding[j] = token_embedding[j] * weight
                # token_embedding[j] = (
                #     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight
                # )

                # add weight method 2:
                # token_embedding[j] = (
                #     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
                # )

                # add weight method 3:
                token_embedding[j] = token_embedding[j] * weight_tensor[j]

        token_embedding = token_embedding.unsqueeze(0)
        embeds.append(token_embedding.cpu())

        # get negative prompt embeddings with weights
        neg_token_tensor = torch.tensor(
            [neg_prompt_token_groups[i]]
            , dtype=torch.long, device=device
        )
        neg_weight_tensor = torch.tensor(
            neg_prompt_weight_groups[i]
            , dtype=torch.float16
            , device=device
        )

        neg_prompt_embeds_1 = pipe.text_encoder(
            neg_token_tensor.to(device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-1].cpu()
        negative_pooled_prompt_embeds = neg_prompt_embeds_1.text_embeds.unsqueeze(1)

        neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states]
        neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0).to(device)

        for z in range(len(neg_weight_tensor)):
            if neg_weight_tensor[z] != 1.0:
                # ow = neg_weight_tensor[z] - 1
                # neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2

                # add weight method 1:
                # neg_token_embedding[z] = neg_token_embedding[z] * neg_weight
                # neg_token_embedding[z] = (
                #     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight
                # )

                # add weight method 2:
                # neg_token_embedding[z] = (
                #     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
                # )

                # add weight method 3:
                neg_token_embedding[z] = neg_token_embedding[z] * neg_weight_tensor[z]

        neg_token_embedding = neg_token_embedding.unsqueeze(0)
        neg_embeds.append(neg_token_embedding.cpu())

    # Free VRAM & RAM
    del prompt_embeds_1_hidden_states, \
        neg_prompt_embeds_1_hidden_states, \
        prompt_embeds_1, \
        neg_prompt_embeds_1

    torch.cuda.empty_cache()

    # each group is (1, chunk_size, dim); concatenate along the token axis
    prompt_embeds = torch.cat(embeds, dim=1).to(device)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1).to(device)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
Example usage:
import torch
from sd_embed.embedding_funcs import get_weighted_text_embeddings_s_cascade
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

device = 'cuda'

pos_prompt = "an image of a shiba inu with (blue eyes:1.4), donning a (green) spacesuit, (cartoon style:1.6)"
neg_prompt = "photograph, real"

# stage C: the prior produces image embeddings from the weighted prompt embeddings
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior",
                                                   variant='bf16',
                                                   torch_dtype=torch.bfloat16).to(device)

generator = torch.Generator(device=device).manual_seed(0)

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_prompt_embeds_pooled
) = get_weighted_text_embeddings_s_cascade(prior, pos_prompt, neg_prompt)

prior_output = prior(
    generator=generator,
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_embeds_pooled=pooled_prompt_embeds,
    negative_prompt_embeds_pooled=negative_prompt_embeds_pooled)

# free the prior's embeddings and move it off the GPU before loading the decoder
del prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_prompt_embeds_pooled
prior.to('cpu')

# stage B: the decoder turns the prior's image embeddings into the final image
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade",
                                                       variant='bf16',
                                                       torch_dtype=torch.float16).to(device)

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_prompt_embeds_pooled
) = get_weighted_text_embeddings_s_cascade(decoder, pos_prompt, neg_prompt)

image = decoder(
    generator=generator,
    num_inference_steps=10,
    guidance_scale=0.0,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_embeds_pooled=pooled_prompt_embeds,
    negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
    image_embeddings=prior_output.image_embeddings.half()).images[0]

del prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_prompt_embeds_pooled
decoder.to('cpu')

image.save('test.png')
Example output:
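Regarding the compel library mentioned at the top, here is a rough sketch of how I would wire it into the prior, for comparison. Treat it as an untested outline rather than a drop-in: compel does not expose a non-normalized last-hidden-state option, so ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED is used here as the closest stand-in for the hidden_states[-1] tensor the function above uses, and the unsqueeze(1) on the pooled tensors is my assumption about the shape the Stable Cascade pipelines expect (mirroring the function above).

import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableCascadePriorPipeline

device = 'cuda'

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior",
                                                   variant='bf16',
                                                   torch_dtype=torch.bfloat16).to(device)

# requires_pooled=True makes compel return (sequence_embeddings, pooled_embeddings).
# NOTE (assumption): LAST_HIDDEN_STATES_NORMALIZED is only an approximation of the
# hidden_states[-1] output used by get_weighted_text_embeddings_s_cascade above,
# so results may differ slightly.
compel = Compel(
    tokenizer=prior.tokenizer,
    text_encoder=prior.text_encoder,
    returned_embeddings_type=ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
    requires_pooled=True
)

# compel uses its own weighting syntax, e.g. (blue eyes)1.4 instead of (blue eyes:1.4)
prompt_embeds, pooled_prompt_embeds = compel(
    "an image of a shiba inu with (blue eyes)1.4, donning a green spacesuit, (cartoon style)1.6")
negative_prompt_embeds, negative_pooled_prompt_embeds = compel("photograph, real")

# make the positive / negative sequence embeddings the same length
[prompt_embeds, negative_prompt_embeds] = compel.pad_conditioning_tensors_to_same_length(
    [prompt_embeds, negative_prompt_embeds])

generator = torch.Generator(device=device).manual_seed(0)

# unsqueeze(1) gives pooled embeddings of shape (batch, 1, dim), matching what the
# function above returns (assumption about what the pipeline expects)
prior_output = prior(
    generator=generator,
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_embeds_pooled=pooled_prompt_embeds.unsqueeze(1),
    negative_prompt_embeds_pooled=negative_pooled_prompt_embeds.unsqueeze(1))

The decoder stage would then be driven the same way as in the example usage above.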
Oh, yes, please, I will read and test it out and merge it. Thank You!
code merged