siliconflow/onediff

Update examples to use onediffx

Amitg1 opened this issue · 1 comments

Please update the examples to use onediffx, specifically text_to_image_sdxl_reuse_pipe.

Lines 126-133: how do I call load_state_dict on specific components when using
from onediff.infer_compiler import oneflow_compile
instead of

    compiled_unet = oneflow_compile(base.unet)
    base.unet = compiled_unet

for each component?

Thank you.

https://github.com/siliconflow/onediff/blob/main/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py#L126-L133

# Update the unet and vae
# load_state_dict(state_dict, strict=True, assign=False), assign is False means copying them inplace into the module’s current parameters and buffers.
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict
print("Loading state_dict of new base into compiled graph")
compiled_unet._torch_module.load_state_dict(new_base.unet.state_dict())
compiled_decoder._torch_module.load_state_dict(new_base.vae.decoder.state_dict())

new_base.unet = compiled_unet
new_base.vae.decoder = compiled_decoder
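
To illustrate the in-place copy that the comment above refers to, here is a minimal sketch using plain PyTorch modules rather than onediff. The names `src` and `dst` are hypothetical stand-ins for the new and the already-compiled module. With the default `assign=False`, `load_state_dict` copies values into the existing parameter tensors, which is presumably what lets the compiled graph pick up the new weights without being rebuilt.

```python
import torch

# Hypothetical stand-ins: `dst` plays the already-compiled module,
# `src` plays the freshly loaded one with different weights.
src = torch.nn.Linear(4, 2)
dst = torch.nn.Linear(4, 2)

weight_before = dst.weight            # keep a reference to the Parameter object
ptr_before = dst.weight.data_ptr()    # and to its underlying storage address

# Default behaviour (assign=False): values are copied into dst's tensors.
dst.load_state_dict(src.state_dict())

same_object = dst.weight is weight_before
same_storage = dst.weight.data_ptr() == ptr_before
values_match = torch.equal(dst.weight, src.weight)
print(same_object, same_storage, values_match)
```

The parameter object and its storage stay the same; only the values change.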

Because the user has to specify which components should have their state_dict updated, we haven't provided a pipeline-level API for this.

You can define one yourself; it would look something like this:

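def reuse_pipe_components(new_pipe, pipe):
    # `components` is a dict property on diffusers pipelines, not a method.
    for k, v in pipe.components.items():
        nc = getattr(new_pipe, k, None)
        # Skip components that carry no weights (e.g. tokenizer, scheduler).
        if not hasattr(nc, "state_dict"):
            continue
        print(f"update {k}")
        # Copy the new weights in place into the already-compiled module.
        if hasattr(v, "_torch_module"):
            v._torch_module.load_state_dict(nc.state_dict())
        else:
            v.load_state_dict(nc.state_dict())
        setattr(new_pipe, k, v)
    return new_pipe

new_pipe = reuse_pipe_components(new_pipe, pipe)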
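
If you only want to reuse some components, a variant that takes an explicit list of names might be easier to control. The following is a dependency-free sketch, not an onediff API: `FakeModule`, `FakeCompiled`, `FakePipe` and `reuse_selected_components` are all hypothetical names, and the fake classes only mimic the parts of a module, a compiled wrapper with `_torch_module`, and a pipeline with a `components` property that the helper touches.

```python
class FakeModule:
    """Hypothetical stand-in for a torch.nn.Module-like component."""
    def __init__(self, weights):
        self._weights = dict(weights)

    def state_dict(self):
        return dict(self._weights)

    def load_state_dict(self, state_dict):
        self._weights.update(state_dict)


class FakeCompiled:
    """Hypothetical stand-in for a compiled wrapper exposing `_torch_module`."""
    def __init__(self, module):
        self._torch_module = module


class FakePipe:
    """Hypothetical stand-in for a pipeline with a `components` dict property."""
    def __init__(self, **components):
        self._names = list(components)
        for name, comp in components.items():
            setattr(self, name, comp)

    @property
    def components(self):
        return {name: getattr(self, name) for name in self._names}


def reuse_selected_components(new_pipe, pipe, names):
    """Load new_pipe's weights into pipe's components, but only for `names`."""
    for name in names:
        compiled = pipe.components[name]
        fresh = getattr(new_pipe, name)
        # Unwrap the compiled wrapper if there is one, else use the module itself.
        target = getattr(compiled, "_torch_module", compiled)
        target.load_state_dict(fresh.state_dict())
        setattr(new_pipe, name, compiled)
    return new_pipe


old_unet = FakeCompiled(FakeModule({"w": 1.0}))
old_pipe = FakePipe(unet=old_unet, text_encoder=FakeModule({"w": 10.0}))
new_text_encoder = FakeModule({"w": 20.0})
new_pipe = FakePipe(unet=FakeModule({"w": 2.0}), text_encoder=new_text_encoder)

# Reuse only the unet; the text encoder of new_pipe is left untouched.
new_pipe = reuse_selected_components(new_pipe, old_pipe, names=["unet"])
```

With a real pipeline you would pass something like `names=["unet"]` and handle `vae.decoder` separately, since it is nested inside the `vae` component rather than being a top-level entry.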

@Amitg1