aredden/flux-fp8-api

Acceleration not as expected

alecyan1993 opened this issue · 1 comments

Hi,

Thanks so much for your work. I'm running on an H100 to test the acceleration, but I can only achieve around 6 it/s for Flux-dev inference. Here's my config file:

{
  "version": "flux-dev",
  "params": {
    "in_channels": 64,
    "vec_in_dim": 768,
    "context_in_dim": 4096,
    "hidden_size": 3072,
    "mlp_ratio": 4.0,
    "num_heads": 24,
    "depth": 19,
    "depth_single_blocks": 38,
    "axes_dim": [
      16,
      56,
      56
    ],
    "theta": 10000,
    "qkv_bias": true,
    "guidance_embed": true
  },
  "ae_params": {
    "resolution": 256,
    "in_channels": 3,
    "ch": 128,
    "out_ch": 3,
    "ch_mult": [
      1,
      2,
      4,
      4
    ],
    "num_res_blocks": 2,
    "z_channels": 16,
    "scale_factor": 0.3611,
    "shift_factor": 0.1159
  },
  "ckpt_path": "/root/flux-fp8-api/flux1-dev.sft",
  "ae_path": "/root/flux-fp8-api/ae.sft",
  "repo_id": "black-forest-labs/FLUX.1-dev",
  "repo_flow": "flux1-dev.sft",
  "repo_ae": "ae.sft",
  "text_enc_max_length": 512,
  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
  "text_enc_device": "cuda:0",
  "ae_device": "cuda:0",
  "flux_device": "cuda:0",
  "flow_dtype": "float16",
  "ae_dtype": "bfloat16",
  "text_enc_dtype": "bfloat16",
  "text_enc_quantization_dtype": "qfloat8",
  "ae_quantization_dtype": "qfloat8",
  "compile_extras": true,
  "compile_blocks": true,
  "offload_ae": false,
  "offload_text_enc": false,
  "offload_flow": false
}

Here's the demo file:

import io
from flux_pipeline import FluxPipeline
import torch
import time

pipe = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev-prequant.json"  # or whatever your config is
)

# compile model
pipe.model.to(memory_format=torch.channels_last)
pipe.model = torch.compile(pipe.model)

for i in range(10):
    output_jpeg_bytes: io.BytesIO = pipe.generate(
        # Required args:
        prompt="A beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
        # Optional args:
        width=1024,
        height=1024,
        num_steps=20,
        guidance=3.5,
        seed=13456,
        strength=0.8,
    )

Would you please have a look and see if there's any mistake I might be making? Thanks!

That is interesting. I would try using bfloat16 for the flow_dtype, since otherwise it'll use my torch-cublas-hgemm, which only really gets speedups on consumer GPUs. Also, you shouldn't compile the model yourself before inference. I would recommend just letting the model get compiled on its own: since you already set compile_blocks and compile_extras to true, it will be compiled automatically.
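Concretely, the suggested config change would look something like this (a sketch showing only the relevant keys; everything else in your config stays the same):

```json
{
  "flow_dtype": "bfloat16",
  "compile_extras": true,
  "compile_blocks": true
}
```

And in the demo script, drop the manual `pipe.model.to(memory_format=torch.channels_last)` and `pipe.model = torch.compile(pipe.model)` lines and just call `pipe.generate(...)` directly; the pipeline handles compilation itself when those flags are set.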