siliconflow/onediff

Add acceleration support for FLUX models

oreasono opened this issue · 11 comments

🚀 The feature, motivation and pitch

https://blackforestlabs.ai/#get-flux
FLUX models are the new SOTA open-source text-to-image models. I am wondering whether models with this slightly different architecture can still benefit from the onediff framework.

Alternatives

diffusers library has a FluxPipeline support https://github.com/black-forest-labs/flux?tab=readme-ov-file#diffusers-integration

Additional context

No response

Bump — I am also looking for this.

We are working on FLUX related optimization.

If you are interested in trying it now, you can try this; it should work: https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler

from onediff.infer_compiler import compile

# Compile only the pipeline's transformer submodule; nothing else in the pipeline is reassigned.
options = '{"mode": "O3"}'  # nexfort options passed as a JSON string; mode can be O2 or O3
flux_pipe = ...  # placeholder: substitute a loaded diffusers FluxPipeline here
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

We are working on FLUX related optimization.

If you are interested in trying it now, you can try this; it should work: https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler

from onediff.infer_compiler import compile

# Compile only the pipeline's transformer submodule; nothing else in the pipeline is reassigned.
options = '{"mode": "O3"}'  # nexfort options passed as a JSON string; mode can be O2 or O3
flux_pipe = ...  # placeholder: substitute a loaded diffusers FluxPipeline here
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

I keep having this error: RuntimeError: RuntimeError: Unsupported timesteps dtype: c10::BFloat16

I keep having this error: RuntimeError: RuntimeError: Unsupported timesteps dtype: c10::BFloat16

Update nexfort, and then set this environment variable:

    # Workaround suggested for "RuntimeError: Unsupported timesteps dtype: c10::BFloat16";
    # presumably turns off nexfort's fused timestep-embedding path (inferred from the name — confirm).
    os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'

Hi, I haven't been able to run the FLUX diffusers pipeline with the nexfort compiler backend. I'm using the settings discussed in this thread, which are also based on the official configurations. @strint, any suggestions?

Environment

Torch version: 2.3.0
CUDA version: 12.1.0
GPU: NVIDIA A100-SXM4-80GB
Nexfort version: 0.1.dev264
Onediff/onediffx: 1.2.1.dev18+g6b53a83b Build from source for dev
Diffusers version: 0.30.0

Test Script

# %%
import os
from onediffx import compile_pipe
from onediff.infer_compiler import compile
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

torch.set_default_device('cuda')
# nexfort workarounds for FLUX recommended in this thread:
# - NEXFORT_FUSE_TIMESTEP_EMBEDDING=0 avoids
#   "RuntimeError: Unsupported timesteps dtype: c10::BFloat16"
#   ref: https://github.com/siliconflow/onediff/issues/1066#issuecomment-2271523799
# - NEXFORT_FX_FORCE_TRITON_SDPA=1 avoids the stride assertion raised after the
#   cuDNN scaled-dot-product-attention kernel:
#   "AssertionError: expected size 24==24, stride 128==557056 at dim=1"
# NOTE(review): these are set after onediff/onediffx have been imported; if nexfort
# reads them at import time they must be moved above the imports — confirm.
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
# %%
model_id: str = "black-forest-labs/FLUX.1-schnell"

pipe = FluxPipeline.from_pretrained(model_id,
                                    torch_dtype=torch.bfloat16,
                                    )
pipe.to("cuda")
# %%
# Alternative: compile only the transformer rather than the whole pipeline:
#   options = '{"mode": "O3"}'
#   pipe.transformer = compile(pipe.transformer, backend="nexfort", options=options)
# compiler optimization options: '{"mode": "O3", "memory_format": "channels_last"}'
options = '{"mode": "O2"}'
pipe = compile_pipe(pipe, backend="nexfort", options=options, fuse_qkv_projections=True)
# %%
# Single seeded generation; the first call also triggers compilation.
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-schnell.png")
# %%
def get_flops_achieved(f):
    """Print the throughput achieved by the zero-argument callable ``f``, in TFLOP/s.

    ``f`` is run once under ``FlopCounterMode`` to count the floating-point
    operations of a single call, then benchmarked with ``do_bench`` (assumed to
    report milliseconds per call). Nothing is returned; the result is printed.
    """
    counter = FlopCounterMode(display=False)
    with counter:
        f()
    flops_per_call = counter.get_total_flops()
    # Milliseconds per call -> calls per second.
    calls_per_second = 1e3 / do_bench(f)
    print(f"{calls_per_second * flops_per_call / 1e12} TF/s")


# Report achieved TFLOP/s for one full 4-step pipeline call (fixed seed 0 on a CPU generator).
get_flops_achieved(lambda: pipe(
                "A tree in the forest",
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
                generator=torch.Generator("cpu").manual_seed(0)
))
# %%
def benchmark_torch_function(iters, f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` with CUDA events.

    Two untimed warm-up calls are made first (the first calls may trigger
    compilation), then ``iters`` calls are timed between two CUDA events.

    Args:
        iters: number of timed calls; must be at least 1.
        f: callable to benchmark; its return value must be iterable because it
            is unpacked into the result tuple.
        *args, **kwargs: forwarded to ``f`` on every call.

    Returns:
        A tuple ``(mean_time_us, *outputs)``. ``mean_time_us`` is the mean time
        per call in microseconds, and ``outputs`` are the unpacked elements of
        the value returned by the last timed call.

    Raises:
        ValueError: if ``iters`` is less than 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    # Warm-up calls, excluded from timing.
    f(*args, **kwargs)
    f(*args, **kwargs)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        # Keep the last output so no extra, untimed call is needed to return it.
        out = f(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    # Event.elapsed_time() returns milliseconds (about 0.5 us resolution);
    # multiplying by 1000 converts the per-call mean to microseconds.
    return start_event.elapsed_time(end_event) * 1000 / iters, *out
# %%
# Time 10 pipeline calls (after the warm-up calls made inside benchmark_torch_function).
# The first returned value is the mean time per call in microseconds.
# NOTE(review): `_` receives the single unpacked element of the pipeline output; this
# assumes the output iterates to exactly one item — confirm for the diffusers version in use.
with torch.inference_mode():
    time_nextfort_flux_fwd, _ = benchmark_torch_function(10,
                                                        pipe,
                                                        "A tree in the forest",
                                                        guidance_scale=0.0,
                                                        num_inference_steps=4,
                                                        max_sequence_length=256,
                                                        )
print(time_nextfort_flux_fwd)  # microseconds per pipeline call

Error Result

   8913         del arg704_1
   8914         del arg705_1
   8915         # Source Nodes: [hidden_states_600, key_98, query_98], Original ATen: [aten._scaled_dot_product_flash_attention, aten._to_copy]
   8916         buf960 = extern_kernels.nexfort_cuda_cudnn_scaled_dot_product_attention(buf958, buf959, reinterpret_tensor(buf957, (1, 24, 4352, 128), (13369344, 128, 3072, 1), 0), dropout_p=0.0, is_causal=False, scale=0.08838834764831843, attn_mask=None)
-> 8917         assert_size_stride(buf960, (1, 24, 4352, 128), (13369344, 557056, 128, 1))
   8918         del buf957
   8919         # Source Nodes: [emb_45], Original ATen: [nexfort_inductor.linear_epilogue]
   8920         buf961 = extern_kernels.nexfort_cuda_constrained_linear_epilogue(buf8, reinterpret_tensor(arg708_1, (3072, 9216), (1, 3072), 0), reinterpret_tensor(arg709_1, (1, 9216), (0, 1), 0), epilogue_ops=['add'], epilogue_tensor_args=[True], epilogue_scalar_args=[None], with_bias=True)

AssertionError: expected size 24==24, stride 128==557056 at dim=1

Update nexfort to the latest version and set these environment variables:

# nexfort workarounds for FLUX, to be set before compiling with backend="nexfort":
# - NEXFORT_FUSE_TIMESTEP_EMBEDDING=0: suggested for "Unsupported timesteps dtype: c10::BFloat16";
# - NEXFORT_FX_FORCE_TRITON_SDPA=1: suggested for the SDPA stride AssertionError.
# os.environ values must be str, so use the string '1' (backticks are not valid Python).
os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'

@antferdom

@strint would it be possible to use fp8? e.g.

@strint would it be possible to use fp8? e.g.

Not yet. We are doing some work on flux.

from onediff.infer_compiler import compile


# Compile only the pipeline's transformer with nexfort.
options = '{"mode": "O3"}'  # mode can be O2 or O3
flux_pipe = ...  # placeholder: substitute a loaded diffusers FluxPipeline here
# (Markdown bold markers around this line were removed; they made it a SyntaxError.)
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

@strint Will this also work for the FLUX dev model?

from onediff.infer_compiler import compile


# Compile only the pipeline's transformer with nexfort.
options = '{"mode": "O3"}'  # mode can be O2 or O3
flux_pipe = ...  # placeholder: substitute a loaded diffusers FluxPipeline here
# (Markdown bold markers around this line were removed; they made it a SyntaxError.)
flux_pipe.transformer = compile(flux_pipe.transformer, backend="nexfort", options=options)

@strint Will this also work for the FLUX dev model?

Yes

Has anyone gotten it working in ComfyUI?

backend='nexfort' raised:
CompilationError: at 47:12:
tmp24 = tmp21 * tmp7
tmp25 = tmp24 + tmp9
tmp26 = 47 + ((-1)*x2)
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp27 * tmp7
tmp29 = 47.0
tmp30 = tmp29 - tmp28
tmp31 = tl.where(tmp23, tmp25, tmp30)
tmp32 = tmp31.to(tl.float32)
tmp33 = tl.where(tmp19, tmp32, tmp9)
tmp34 = tl.where(tmp2, tmp17, tmp33)
tmp36 = tmp35.to(tl.float32)

import torch
from onediff.infer_compiler import compile
import os

class TorchCompileModel:
    """ComfyUI node that replaces a model's ``diffusion_model`` object with a compiled version.

    The ``backend`` input selects the compiler:

    * ``"nexfort"`` (listed first, so it is the default) uses ``compile`` imported
      from ``onediff.infer_compiler`` with ``{"mode": "O3"}`` options;
    * ``"inductor"`` / ``"cudagraphs"`` are passed straight to ``torch.compile``.

    The compiled module is attached with ``add_object_patch`` to a clone of the
    input model; the input model object itself is not patched.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "backend": (["nexfort", "inductor", "cudagraphs"],),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend):
        """Return ``(patched_model,)`` whose ``diffusion_model`` is compiled with ``backend``.

        Args:
            model: ComfyUI model wrapper providing ``clone``, ``get_model_object``
                and ``add_object_patch``.
            backend: one of the choices declared in ``INPUT_TYPES``.
        """
        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        if backend == "nexfort":
            # Workarounds suggested for nexfort on FLUX: the BFloat16 timesteps
            # RuntimeError and the SDPA stride AssertionError.
            os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
            os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
            options = '{"mode": "O3"}'  # mode can be O2 or O3
            # NOTE: `compile` here is onediff's compile (module-level import), which
            # shadows the Python builtin of the same name.
            compiled = compile(diffusion_model, backend="nexfort", options=options)
        else:
            compiled = torch.compile(model=diffusion_model, backend=backend)
        m.add_object_patch("diffusion_model", compiled)
        return (m, )

NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,