argmaxinc/DiffusionKit

FLUX LoRA Inference Support

Opened this issue · 3 comments

This should:

  • Define a canonical LoRA format (checkpoint structure)
  • Implement LoRA checkpoint adapters for top-2 "source" LoRA formats into DiffusionKit's "target" (canonical) LoRA format
  • Allow for fusing LoRA weights into base weights as an optional argument
  • Document the process for adding checkpoint adapters for additional source implementations

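To make the "fusing LoRA weights into base weights" item concrete, here is a minimal sketch of the standard LoRA update, W' = W + (alpha / rank) * (B @ A). The function name, argument layout, and the alpha/rank scaling convention are assumptions drawn from common LoRA implementations, not an existing DiffusionKit API:

```python
import numpy as np

def fuse_lora(base_weight, lora_down, lora_up, alpha):
    # lora_down ("A") has shape (rank, in_features);
    # lora_up ("B") has shape (out_features, rank).
    # Scaling by alpha / rank follows the usual LoRA convention;
    # the canonical DiffusionKit convention is still to be defined.
    rank = lora_down.shape[0]
    return base_weight + (alpha / rank) * (lora_up @ lora_down)

# Toy example: 4x4 base weight, rank-2 LoRA.
W = np.zeros((4, 4))
A = np.ones((2, 4))   # lora_down
B = np.ones((4, 2))   # lora_up
fused = fuse_lora(W, A, B, alpha=2.0)
```

Keeping fusion behind an optional argument preserves the ability to hot-swap LoRAs at runtime, at the cost of the extra matmul per fused layer when enabled.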
Which two implementations count as "top-2 most used" is subjective; download counts for the most popular LoRAs on Hugging Face can serve as a proxy. For example, sorted by downloads:

(Screenshot, 2024-09-10: Hugging Face FLUX LoRAs sorted by downloads)

Exploration kickstarter code:

from huggingface_hub import hf_hub_download
from safetensors import safe_open

def load_model(path):
    # Load every tensor from a safetensors checkpoint into a dict,
    # so we can inspect key names and shapes across LoRA formats.
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors

TESTED_LORAS = [
    {"repo": "XLabs-AI/flux-RealismLora", "rel_path": "lora.safetensors"},
    {"repo": "ByteDance/Hyper-SD", "rel_path": "Hyper-FLUX.1-dev-16steps-lora.safetensors"},
]

for lora in TESTED_LORAS:
    ckpt_path = hf_hub_download(repo_id=lora["repo"], filename=lora["rel_path"])
    ckpt = load_model(ckpt_path)
    print(len(ckpt), {k: v.shape for k, v in ckpt.items()})
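Once the key/shape inventory above is in hand, a checkpoint adapter is mostly a key-renaming pass into the canonical format. The sketch below is hypothetical: the "lora_A"/"lora_B" and "lora_down"/"lora_up" patterns mirror common conventions (PEFT-style vs. kohya-style), but the canonical target names are an assumption, since DiffusionKit's format is exactly what this issue would define:

```python
def adapt_peft_keys(tensors):
    # Rename PEFT-style keys (".lora_A." / ".lora_B.") into an assumed
    # canonical ".lora_down." / ".lora_up." layout; tensor values pass
    # through unchanged.
    canonical = {}
    for key, value in tensors.items():
        new_key = key.replace(".lora_A.weight", ".lora_down.weight")
        new_key = new_key.replace(".lora_B.weight", ".lora_up.weight")
        canonical[new_key] = value
    return canonical

# Usage with placeholder values standing in for tensors:
adapted = adapt_peft_keys({
    "transformer.blocks.0.attn.lora_A.weight": "A0",
    "transformer.blocks.0.attn.lora_B.weight": "B0",
})
```

One adapter per source format, each emitting the same canonical dict, keeps the documented process for adding new sources down to "write one key-mapping function".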

cc: @raoulritter who expressed interest in this earlier

Additional resource: This issue covers the challenges of dealing with many source implementations and has helpful pointers.