Add acceleration support for FLUX models
oreasono opened this issue · 11 comments
🚀 The feature, motivation and pitch
https://blackforestlabs.ai/#get-flux
FLUX models are the new SOTA open-source text-to-image models. I am wondering whether this slightly different architecture can still benefit from the onediff framework.
Alternatives
The diffusers library has FluxPipeline support: https://github.com/black-forest-labs/flux?tab=readme-ov-file#diffusers-integration
Additional context
No response
Bump, I am also looking for this.
We are working on FLUX-related optimization.
If you want to try it now, this should work: https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler
from onediff.infer_compiler import compile

# module is the model you want to compile
options = '{"mode": "O3"}'  # mode can be O2 or O3
flux_pipe = ...
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)
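Put together as a self-contained sketch (assumptions: the diffusers FluxPipeline API and the FLUX.1-schnell checkpoint with its zero-guidance, 4-step defaults; only the compile call comes from the snippet above):

import torch
from diffusers import FluxPipeline
from onediff.infer_compiler import compile

# Load FLUX.1-schnell in bfloat16 and move it to the GPU.
flux_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Compile only the transformer; the first pipeline call triggers compilation.
options = '{"mode": "O3"}'
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

image = flux_pipe(
    "A cat holding a sign that says hello world",
    guidance_scale=0.0,        # schnell is distilled, so no classifier-free guidance
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
image.save("flux-schnell.png")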
I keep having this error:
RuntimeError: RuntimeError: Unsupported timesteps dtype: c10::BFloat16
Update nexfort, and then set this environment variable:
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
Hi, I haven't been able to run the FLUX diffusers pipeline with the nexfort compiler backend. I'm using the settings discussed in this thread, and also the official configurations. @strint, any suggestions?
Environment
Torch version: 2.3.0
CUDA version: 12.1.0
GPU: NVIDIA A100-SXM4-80GB
Nexfort version: 0.1.dev264
Onediff/onediffx: 1.2.1.dev18+g6b53a83b
Built from source for dev
Diffusers version: 0.30.0
Test Script
# %%
import os
from onediffx import compile_pipe
from onediff.infer_compiler import compile
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
torch.set_default_device('cuda')
# RuntimeError: RuntimeError: Unsupported timesteps dtype: c10::BFloat16
# ref: https://github.com/siliconflow/onediff/issues/1066#issuecomment-2271523799
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
# %%
model_id: str = "black-forest-labs/FLUX.1-schnell"
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
# %%
"""
options = '{"mode": "O3"}'
pipe.transformer = compile(pipe.transformer, backend="nexfort", options=options)
"""
# compiler optimization options: '{"mode": "O3", "memory_format": "channels_last"}'
options = '{"mode": "O2"}'
pipe = compile_pipe(pipe, backend="nexfort", options=options, fuse_qkv_projections=True)
# %%
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-schnell.png")
# %%
def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12} TF/s")

get_flops_achieved(lambda: pipe(
    "A tree in the forest",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
))
# %%
def benchmark_torch_function(iters, f, *args, **kwargs):
    # Two warmup calls before timing
    f(*args, **kwargs)
    f(*args, **kwargs)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        f(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds (with 0.5 microsecond resolution);
    # multiply by 1000 to report microseconds per iteration.
    # One extra call returns the output alongside the timing.
    return start_event.elapsed_time(end_event) * 1000 / iters, *f(*args, **kwargs)
# %%
with torch.inference_mode():
    time_nextfort_flux_fwd, _ = benchmark_torch_function(
        10,
        pipe,
        "A tree in the forest",
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
print(time_nextfort_flux_fwd)
Error Result
8913 del arg704_1
8914 del arg705_1
8915 # Source Nodes: [hidden_states_600, key_98, query_98], Original ATen: [aten._scaled_dot_product_flash_attention, aten._to_copy]
8916 buf960 = extern_kernels.nexfort_cuda_cudnn_scaled_dot_product_attention(buf958, buf959, reinterpret_tensor(buf957, (1, 24, 4352, 128), (13369344, 128, 3072, 1), 0), dropout_p=0.0, is_causal=False, scale=0.08838834764831843, attn_mask=None)
-> 8917 assert_size_stride(buf960, (1, 24, 4352, 128), (13369344, 557056, 128, 1))
8918 del buf957
8919 # Source Nodes: [emb_45], Original ATen: [nexfort_inductor.linear_epilogue]
8920 buf961 = extern_kernels.nexfort_cuda_constrained_linear_epilogue(buf8, reinterpret_tensor(arg708_1, (3072, 9216), (1, 3072), 0), reinterpret_tensor(arg709_1, (1, 9216), (0, 1), 0), epilogue_ops=['add'], epilogue_tensor_args=[True], epilogue_scalar_args=[None], with_bias=True)
AssertionError: expected size 24==24, stride 128==557056 at dim=1
Update nexfort to the latest version and set
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
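For placement, a minimal sketch (assumption: nexfort reads these variables during compilation, so set them before calling compile, as the test script above does):

import os

# Workarounds reported in this thread; assumed to be read at compile time.
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'  # avoids "Unsupported timesteps dtype: c10::BFloat16"
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'     # avoids the cuDNN SDPA stride AssertionError

from onediff.infer_compiler import compile
# ...then compile the transformer as shown earlier in the thread.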
@strint would it be possible to use fp8? e.g.
Not yet. We are doing some work on FLUX.
from onediff.infer_compiler import compile
options = '{"mode": "O3"}' # mode can be O2 or O3
flux_pipe = ...
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)
@strint Will this also work for the Flux Dev model?
Yes
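For reference, only the checkpoint id and sampling settings should change. A hedged sketch (assumptions: the black-forest-labs/FLUX.1-dev checkpoint and the guidance/step values from its model card; the compile call itself is the one from this thread):

import torch
from diffusers import FluxPipeline
from onediff.infer_compiler import compile

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer = compile(pipe.transformer, backend="nexfort", options='{"mode": "O3"}')

# Unlike schnell, dev is guidance-distilled: nonzero guidance and more steps.
image = pipe(
    "A cat holding a sign that says hello world",
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
).images[0]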
Has anyone gotten it working in ComfyUI?
backend='nexfort' raised:
CompilationError: at 47:12:
tmp24 = tmp21 * tmp7
tmp25 = tmp24 + tmp9
tmp26 = 47 + ((-1)*x2)
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp27 * tmp7
tmp29 = 47.0
tmp30 = tmp29 - tmp28
tmp31 = tl.where(tmp23, tmp25, tmp30)
tmp32 = tmp31.to(tl.float32)
tmp33 = tl.where(tmp19, tmp32, tmp9)
tmp34 = tl.where(tmp2, tmp17, tmp33)
tmp36 = tmp35.to(tl.float32)
import torch
from onediff.infer_compiler import compile
import os

class TorchCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "backend": (["inductor", "cudagraphs"],),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend):
        os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
        os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
        options = '{"mode": "O3"}'  # mode can be O2 or O3
        m = model.clone()
        # m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
        # Note: the "backend" input is ignored here; nexfort is hardcoded.
        m.add_object_patch("diffusion_model", compile(m.get_model_object("diffusion_model"), backend="nexfort", options=options))
        return (m,)

NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
}