hidet-org/hidet

[Bug] Stable Diffusion model compilation

alexeigor opened this issue · 2 comments

Describe the bug
BackendCompilerFailed: hidet_backend raised NotImplementedError: hidet: Tensor.to(..., device=...) is not supported
for symbolic tensors., occurred when calling tensor_to with
args: (<hidet.Tensor object at 0x7f5bbe05e830>, device(type='cuda', index=0))
kwargs: {}
tensor_to is defined at

To Reproduce
https://colab.research.google.com/drive/1XcQ2JQk8-3QoikWyUjO6umtH3z_eiVql#scrollTo=oK2NFZu0BZXJ

Environment

  • Google Colab
  • T4 GPU

Hi @alexeigor, thanks for trying hidet out!

I was able to reproduce the issue with the Colab provided above; the error seems to come from this line:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py#L675

You can safely insert the line
torch._dynamo.config.suppress_errors = True
before the call to torch.compile to temporarily work around this issue while we get a proper fix in. The error does not affect the functionality of the compiled model, and all the other operators compile fine.
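For reference, a minimal sketch of the workaround (the commented-out `unet` compile call is a placeholder, not taken from the Colab):

```python
import torch
import torch._dynamo

# Fall back to eager execution for any op the backend cannot compile,
# instead of raising BackendCompilerFailed
torch._dynamo.config.suppress_errors = True

# then compile with the hidet backend as usual, e.g.:
# pipe.unet = torch.compile(pipe.unet, backend="hidet")
```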

Also, we have not tested channels_last with hidet, so you might want to disable it just to be safe (it may compile but produce incorrect outputs). We will make sure this is tested very soon.
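If the model was put into channels_last, a quick way to switch it back is the standard PyTorch memory-format API (the `unet` name below is a placeholder):

```python
import torch

# Convert a module's parameters and buffers back to the default
# contiguous (NCHW) memory format before compiling:
# unet = unet.to(memory_format=torch.contiguous_format)

# The same idea applies to individual tensors:
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
x = x.contiguous()  # back to the default NCHW layout
```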

Finally, to get the best performance from torch.compile(..., backend=hidet), you might want to set these configs before calling torch.compile:

# use a larger auto-tuning search space
hidet.torch.dynamo_config.search_space(2)
# automatically transform the model to use the float16 data type
hidet.torch.dynamo_config.use_fp16(True)
# use float16 as the accumulation data type in reduction operators
hidet.torch.dynamo_config.use_fp16_reduction(True)
# use tensor cores
hidet.torch.dynamo_config.use_tensor_core()

With search space 2, the model might take a long time to compile due to auto-tuning. Running this in a local notebook with a stronger CPU would be faster.
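Putting it together, a sketch of how these configs fit around the compile call (this assumes hidet is installed; `pipe.unet` is a placeholder for whatever module you compile):

```python
import torch
import hidet

# set the tuning/precision configs once, before torch.compile
hidet.torch.dynamo_config.search_space(2)
hidet.torch.dynamo_config.use_fp16(True)
hidet.torch.dynamo_config.use_fp16_reduction(True)
hidet.torch.dynamo_config.use_tensor_core()

# then compile with the hidet backend; the first call triggers tuning
# pipe.unet = torch.compile(pipe.unet, backend="hidet")
```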