Update examples to use onediffx
Amitg1 commented
Please update the examples to use onediffx, specifically text_to_image_sdxl_reuse_pipe.
Lines 126-133: how do I use load_state_dict to load weights into specific components when using
from onediff.infer_compiler import oneflow_compile
instead of
compiled_unet = oneflow_compile(base.unet)
base.unet = compiled_unet
for each component?
Thank you.
strint commented
# Update the unet and vae
# load_state_dict(state_dict, strict=True, assign=False): with assign=False, the tensors in state_dict are copied in place into the module's current parameters and buffers.
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict
print("Loading state_dict of new base into compiled graph")
compiled_unet._torch_module.load_state_dict(new_base.unet.state_dict())
compiled_decoder._torch_module.load_state_dict(new_base.vae.decoder.state_dict())
new_base.unet = compiled_unet
new_base.vae.decoder = compiled_decoder
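A minimal plain-PyTorch sketch of what copying "in place" means here: after load_state_dict (assign defaults to False), the destination keeps the same Parameter objects and the same storage; only the values change. The two Linear layers are illustrative stand-ins (an assumption, not onediff objects) for new_base.unet and compiled_unet._torch_module. The presumption is that this in-place behavior is what lets the already-compiled graph be reused without recompiling.

```python
import torch

# Same architecture, different random weights.
src = torch.nn.Linear(4, 2)  # stand-in for new_base.unet (hypothetical)
dst = torch.nn.Linear(4, 2)  # stand-in for compiled_unet._torch_module (hypothetical)

weight_obj_before = dst.weight      # the Parameter object itself
ptr_before = dst.weight.data_ptr()  # address of its underlying storage

# assign is left at its default (False): values are copied into the existing tensors.
dst.load_state_dict(src.state_dict())

print(dst.weight is weight_obj_before)      # True: same Parameter object
print(dst.weight.data_ptr() == ptr_before)  # True: same storage
print(torch.equal(dst.weight, src.weight))  # True: holds the new values
```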
Because the user has to specify which components should have their state_dict updated, we haven't provided a pipeline-level API for this.
You can define one yourself; it could look like this:
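import torch

def reuse_pipe_components(new_pipe, pipe):
    # pipe.components is a dict property in diffusers: {component_name: component}
    for k, v in pipe.components.items():
        if not hasattr(new_pipe, k):
            continue
        nc = getattr(new_pipe, k)
        if hasattr(v, "_torch_module"):
            # Compiled component: copy the new weights into the wrapped torch module in place
            v._torch_module.load_state_dict(nc.state_dict())
        elif isinstance(v, torch.nn.Module):
            # Plain torch module: copy the new weights in place
            v.load_state_dict(nc.state_dict())
        else:
            # Tokenizers, schedulers, None entries, etc. have no state_dict; keep new_pipe's own
            continue
        print(f"update {k}")
        # Reuse the (compiled) module from the old pipe in the new pipe
        setattr(new_pipe, k, v)
    return new_pipe

new_pipe = reuse_pipe_components(new_pipe, pipe)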
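Since the user has to choose which components to update, one possible variant takes the component names explicitly, including dotted paths such as "vae.decoder" (the compiled_decoder case). This is only a sketch: reuse_selected_components, _get_path and _set_path are hypothetical helper names, and the only onediff detail assumed is the _torch_module attribute used in compiled_unet._torch_module.

```python
from functools import reduce


def _get_path(obj, path):
    # Resolve a dotted attribute path, e.g. "vae.decoder" -> obj.vae.decoder.
    return reduce(getattr, path.split("."), obj)


def _set_path(obj, path, value):
    # Set the attribute at a dotted path, e.g. obj.vae.decoder = value.
    parent_path, _, attr = path.rpartition(".")
    parent = _get_path(obj, parent_path) if parent_path else obj
    setattr(parent, attr, value)


def reuse_selected_components(new_pipe, pipe, names):
    """Hypothetical helper: for each named component, copy new_pipe's weights
    into pipe's (possibly compiled) module in place, then put that module on new_pipe."""
    for name in names:
        compiled = _get_path(pipe, name)
        fresh = _get_path(new_pipe, name)
        # Assumed: compiled objects expose the original torch module as _torch_module;
        # otherwise treat the component itself as the torch module.
        target = getattr(compiled, "_torch_module", compiled)
        target.load_state_dict(fresh.state_dict())
        _set_path(new_pipe, name, compiled)
    return new_pipe
```

With real pipelines the call would presumably be new_pipe = reuse_selected_components(new_pipe, pipe, ["unet", "vae.decoder"]). Passing names explicitly avoids guessing which entries in pipe.components carry weights, and dotted paths reach nested submodules that pipe.components does not list.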