microsoft/webnn-developer-preview

SD Turbo unet details

mateusz-malicki opened this issue · 8 comments

Can you provide a bit more detail than "model has been optimized to work with WebNN" regarding the conversion of the SD unet to ONNX so that it works nicely with WebNN?
A conversion script would be great, but hints in any form would be appreciated.
Thanks!

The original model was provided by @fdwr; he may have more to say.

The original model has dynamic input shapes, and ORT-Web provides various graph optimizations to improve performance. These can be applied during ORT session creation via the graphOptimizationLevel option, but doing so at session creation adds considerable overhead to device memory and creation time.

So we first used the onnxruntime_perf_test tool to optimize the model to static shapes and apply Level 1 graph optimization (higher-level optimizations are not supported by WebNN at present), then dumped the optimized model to disk as the target test model.
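
For reference, a roughly equivalent offline pass can also be done with the onnxruntime Python API instead of onnxruntime_perf_test: pin the free dimensions, cap optimization at Level 1 (basic), and let session creation dump the optimized model to disk. This is a minimal sketch, not the exact command we ran, and the dimension names (batch, channels, height, width, sequence) are assumptions that must match the free dimension names in the exported unet.

import onnxruntime as ort

sessOptions = ort.SessionOptions()
# Level 1 ("basic") optimizations only; higher levels are not usable with WebNN.
sessOptions.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
# Write the optimized, constant-folded model back to disk.
sessOptions.optimized_model_filepath = "unet-float32-static-optimized.onnx"
# Pin the free dimensions to static values (dimension names are assumptions).
for name, value in [("batch", 2), ("channels", 4), ("height", 64), ("width", 64), ("sequence", 77)]:
    sessOptions.add_free_dimension_override_by_name(name, value)

# Creating the session applies the optimizations and saves the result.
ort.InferenceSession("unet-float32.onnx", sessOptions, providers=["CPUExecutionProvider"])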

fdwr commented

Sorry, it's been so many months since I converted the model that I don't have the exact steps anymore, but IIRC:

  1. Start with the float32 unet model (starting with the float16 version will fail, either by running out of memory while constant folding or by overflowing the 2GB Protobuf limit), using a model with a separate weights.pb file like this one.
  2. Fix the model input sizes to typical static values (batch=2, channels=4, latent width=64, latent height=64, sequence=77...). I edited unet's model.onnx file as prototxt via Onnx2Text to update the graph/input/type/shape/dim values (see the sketch after the conversion script below for a programmatic alternative).
  3. Optimize the model to fold any constants and prune unused nodes via onnxruntime_perf_test.exe (again, use the float32 model rather than the original float16 model, which otherwise inserts extra casts and runs into missing-operator issues).
    • Set optimization to 0 or 1, as higher levels overflow the 2GB Protobuf limit.
    • Note passing "-f" for free dimension overrides should not even be necessary if you already fixed the input shapes earlier.
    • onnxruntime_perf_test.exe -o 0 -r 1 -f batch:2 -f channels:4 -f height:64 -f width:64 -f sequence:77 -I -u "Stable-Diffusion-v1.5-unet-float16-static-sample(2,4,64,64)-timestep(2)-encoder(2,77,768)-embedded-weights.onnx" "Stable-Diffusion-v1.5-unet-float32-static-sample(2,4,64,64)-timestep(2)-encoder(2,77,768).onnx"
  4. Convert float32 to float16 (also applies shape inference and embeds the weights into the .onnx file):

Pip freeze:

onnx==1.13.1
onnxconverter-common==1.9.0
onnxoptimizer==0.2.7
onnxruntime==1.14.1

ConvertToFloat16.py (given an input .onnx filename, it writes a new output .onnx file with a float16 suffix)

import onnx
import os
import sys
from onnxconverter_common import float16

# Set to True to write the float16 weights to a separate .weights.pb file.
saveWeightsExternally = False

if len(sys.argv) <= 1:
    print("Pass an ONNX filename.")
    quit()

# Add a filename suffix of "float16".
filePath = sys.argv[1]
filePathSplitExtension = os.path.splitext(filePath)
filePathNoExtension = filePathSplitExtension[0]
fileNameExtension = filePathSplitExtension[1]
fileName = os.path.basename(filePathNoExtension)
fileSuffixSeparator = '-'
if ('_' in fileName) and not ('-' in fileName):
    fileSuffixSeparator = '_'
newFilePath = filePathNoExtension + fileSuffixSeparator + "float16" + fileNameExtension
newWeightsFilename = fileName + fileSuffixSeparator + "float16" + ".weights.pb"

print("Input file: ", filePath)
print("Output file:", newFilePath)

print("Loading input model")
model = onnx.load(filePath)
print("Applying shape inference")
onnx.shape_inference.infer_shapes_path(model_path = filePath, output_path = newFilePath)
print("Reloading input model with inferred shapes")
shapedModel = onnx.load(newFilePath)

print("Converting model to float16")
modelFloat16 = float16.convert_float_to_float16(shapedModel, keep_io_types=False, disable_shape_infer=False)

if saveWeightsExternally:
    print("Saving output model to " + newFilePath + " and " + newWeightsFilename)
else:
    print("Saving output model to " + newFilePath)

onnx.save_model(modelFloat16, newFilePath, save_as_external_data=saveWeightsExternally, all_tensors_to_one_file=True, location=newWeightsFilename)
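
As a programmatic alternative for step 2 above (fixing the input shapes), the dims can also be overwritten with the onnx Python API instead of editing the prototxt from Onnx2Text. A minimal sketch; the input names and shapes below (sample, timestep, encoder_hidden_states) are assumptions taken from the filename above and should be checked against the actual exported model (e.g. in Netron). Keeping the weights external avoids the 2GB Protobuf limit; save the edited model next to the original weights.pb.

import onnx

# Load the graph only, leaving the large tensors in the external weights.pb
# (the float32 unet would otherwise exceed the 2GB Protobuf limit).
model = onnx.load("unet/model.onnx", load_external_data=False)

# Static shapes keyed by input name; adjust names and values to your model.
staticShapes = {
    "sample": [2, 4, 64, 64],
    "timestep": [2],
    "encoder_hidden_states": [2, 77, 768],
}

for graphInput in model.graph.input:
    shape = staticShapes.get(graphInput.name)
    if shape is None:
        continue
    dims = graphInput.type.tensor_type.shape.dim
    for dim, value in zip(dims, shape):
        # Drop any symbolic name (e.g. "batch") and pin the dimension.
        dim.ClearField("dim_param")
        dim.dim_value = value

# Save into the same directory as weights.pb so the external references still resolve.
onnx.save_model(model, "unet/model-static.onnx")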

Thanks, this is very helpful! I still have a small question regarding the fp16<->fp32 casts around InstanceNormalization: were they edited in manually, as in point 2, but on the model already saved in fp16?

fdwr commented

regarding the fp16<->fp32 casts around InstanceNormalization

🤔 I didn't do anything special for InstanceNormalization, but I recall @Honry mentioned some precision issues related to it. Wanming?

Exactly, I did cast fp16<->fp32 around InstanceNormalization to fix a precision issue.

I used @fdwr's awesome tool: https://github.com/fdwr/Onnx2Text to convert the model to text and manually insert the Cast ops.
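
For anyone who would rather avoid hand-editing the text form, onnxconverter_common's converter can also keep selected op types in float32 and insert the fp16<->fp32 Cast pairs for you via its op_block_list parameter. A minimal sketch (not what we shipped; worth verifying that the automatically inserted Casts behave the same as the manual edit):

import onnx
from onnxconverter_common import float16

model = onnx.load("unet-float32-static.onnx")
# Keep InstanceNormalization in float32; the converter inserts the
# surrounding fp16<->fp32 Cast nodes for the blocked ops.
modelFloat16 = float16.convert_float_to_float16(
    model,
    keep_io_types=False,
    op_block_list=["InstanceNormalization"],
)
onnx.save_model(modelFloat16, "unet-float16-static.onnx")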

Thanks a lot for your time guys, now it's clear!

Are the details for the VAE the same? I am asking because I want to convert the encoder...

fdwr commented

Are the details for the VAE the same? I am asking because I want to convert the encoder...

@eyaler I didn't convert the VAE encoder (just the decoder) because we didn't have any image-to-image demos (just text-to-image), but yes, the details should be the same.