facebookincubator/AITemplate

Stable Diffusion demo fails to compile VAE

AlphaAtlas opened this issue · 7 comments

I am trying to compile Stable Diffusion 1.5 (instead of version 2.0 which the demo uses) and am consistently hitting an error with the VAE:

...
2023-05-24 19:15:51,817 INFO <aitemplate.backend.builder> Using 16 CPU for building
2023-05-24 19:23:51,857 INFO <aitemplate.compiler.compiler> compiled the final .so file elapsed time: 0:08:00.038873
[19:23:52] model_container.cu:67: Device Runtime Version: 12010; Driver Version: 12010
[19:23:52] model_container.cu:81: Hardware accelerator device properties:
  Device:
     ASCII string identifying device: NVIDIA GeForce RTX 2060 with Max-Q Design
     Major compute capability: 7
     Minor compute capability: 5
     UUID: GPU-414a7c6b-aad7-20a5-c60d-8cf0708406ff
     Unique identifier for a group of devices on the same multi-GPU board: 0
     PCI bus ID of the device: 1
     PCI device ID of the device: 0
     PCI domain ID of the device: 0
  Memory limits:
     Constant memory available on device in bytes: 65536
     Global memory available on device in bytes: 6214516736
     Size of L2 cache in bytes: 3145728
     Shared memory available per block in bytes: 49152
     Shared memory available per multiprocessor in bytes: 65536
[19:23:52] model_container.cu:85: Init AITemplate Runtime with 1 concurrency
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/alpha/clone/AITemplate/examples/05_stable_diffusion/scripts/compile_alt.py:138 in <module> │
│                                                                                                  │
│   135                                                                                            │
│   136                                                                                            │
│   137 if __name__ == "__main__":                                                                 │
│ ❱ 138 │   compile_diffusers()                                                                    │
│   139                                                                                            │
│                                                                                                  │
│ /usr/lib/python3.11/site-packages/click/core.py:1130 in __call__                                 │
│                                                                                                  │
│   1127 │                                                                                         │
│   1128 │   def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:                           │
│   1129 │   │   """Alias for :meth:`main`."""                                                     │
│ ❱ 1130 │   │   return self.main(*args, **kwargs)                                                 │
│   1131                                                                                           │
│   1132                                                                                           │
│   1133 class Command(BaseCommand):                                                               │
│                                                                                                  │
│ /usr/lib/python3.11/site-packages/click/core.py:1055 in main                                     │
│                                                                                                  │
│   1052 │   │   try:                                                                              │
│   1053 │   │   │   try:                                                                          │
│   1054 │   │   │   │   with self.make_context(prog_name, args, **extra) as ctx:                  │
│ ❱ 1055 │   │   │   │   │   rv = self.invoke(ctx)                                                 │
│   1056 │   │   │   │   │   if not standalone_mode:                                               │
│   1057 │   │   │   │   │   │   return rv                                                         │
│   1058 │   │   │   │   │   # it's not safe to `ctx.exit(rv)` here!                               │
│                                                                                                  │
│ /usr/lib/python3.11/site-packages/click/core.py:1404 in invoke                                   │
│                                                                                                  │
│   1401 │   │   │   echo(style(message, fg="red"), err=True)                                      │
│   1402 │   │                                                                                     │
│   1403 │   │   if self.callback is not None:                                                     │
│ ❱ 1404 │   │   │   return ctx.invoke(self.callback, **ctx.params)                                │
│   1405 │                                                                                         │
│   1406 │   def shell_complete(self, ctx: Context, incomplete: str) -> t.List["CompletionItem"]:  │
│   1407 │   │   """Return a list of completions for the incomplete value. Looks                   │
│                                                                                                  │
│ /usr/lib/python3.11/site-packages/click/core.py:760 in invoke                                    │
│                                                                                                  │
│    757 │   │                                                                                     │
│    758 │   │   with augment_usage_errors(__self):                                                │
│    759 │   │   │   with ctx:                                                                     │
│ ❱  760 │   │   │   │   return __callback(*args, **kwargs)                                        │
│    761 │                                                                                         │
│    762 │   def forward(                                                                          │
│    763 │   │   __self, __cmd: "Command", *args: t.Any, **kwargs: t.Any  # noqa: B902             │
│                                                                                                  │
│ /home/alpha/clone/AITemplate/examples/05_stable_diffusion/scripts/compile_alt.py:126 in          │
│ compile_diffusers                                                                                │
│                                                                                                  │
│   123 │   │   controlnet=True if controlnet else False,                                          │
│   124 │   )                                                                                      │
│   125 │   # VAE                                                                                  │
│ ❱ 126 │   compile_vae(                                                                           │
│   127 │   │   pipe.vae,                                                                          │
│   128 │   │   batch_size=batch_size,                                                             │
│   129 │   │   width=width,                                                                       │
│                                                                                                  │
│ /home/alpha/clone/AITemplate/examples/05_stable_diffusion/src/compile_lib/compile_vae_alt.py:146 │
│ in compile_vae                                                                                   │
│                                                                                                  │
│   143 │   ait_vae.name_parameter_tensor()                                                        │
│   144 │                                                                                          │
│   145 │   pt_mod = pt_mod.eval()                                                                 │
│ ❱ 146 │   params_ait = map_vae_params(ait_vae, pt_mod)                                           │
│   147 │                                                                                          │
│   148 │   Y = ait_vae.decode(ait_input)                                                          │
│   149 │   mark_output(Y)                                                                         │
│                                                                                                  │
│ /home/alpha/clone/AITemplate/examples/05_stable_diffusion/src/compile_lib/compile_vae_alt.py:58  │
│ in map_vae_params                                                                                │
│                                                                                                  │
│    55 │   │   elif name.endswith("attention.proj_q.weight"):                                     │
│    56 │   │   │   prefix = name[: -len("attention.proj_q.weight")]                               │
│    57 │   │   │   pt_name = prefix + "query.weight"                                              │
│ ❱  58 │   │   │   mapped_pt_params[ait_name] = pt_params[pt_name]                                │
│    59 │   │   elif name.endswith("attention.proj_q.bias"):                                       │
│    60 │   │   │   prefix = name[: -len("attention.proj_q.bias")]                                 │
│    61 │   │   │   pt_name = prefix + "query.bias"                                                │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'decoder.mid_block.attentions.0.query.weight'

~/clone/AITemplate/examples/05_stable_diffusion main* 9m 24s
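For context, the failing line in map_vae_params looks up the PyTorch parameter by its legacy name (e.g. prefix + "query.weight"). A plausible explanation, and this is an assumption on my part, is that recent diffusers releases renamed the VAE attention parameters from query/key/value/proj_attn to to_q/to_k/to_v/to_out.0, so the legacy key no longer exists in the state dict. A tolerant lookup could sketch the workaround like this (hypothetical helper, not the actual fix):

```python
# Sketch of a defensive parameter lookup for map_vae_params.
# ASSUMPTION: newer diffusers renamed VAE attention weights from
# "query"/"key"/"value"/"proj_attn" to "to_q"/"to_k"/"to_v"/"to_out.0",
# which would explain the KeyError on
# 'decoder.mid_block.attentions.0.query.weight'.

LEGACY_TO_NEW = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",
}

def lookup_pt_param(pt_params: dict, pt_name: str):
    """Try the legacy parameter name first, then fall back to the renamed one."""
    if pt_name in pt_params:
        return pt_params[pt_name]
    for old, new in LEGACY_TO_NEW.items():
        renamed = pt_name.replace(f".{old}.", f".{new}.")
        if renamed != pt_name and renamed in pt_params:
            return pt_params[renamed]
    raise KeyError(pt_name)
```

With a lookup like this, the mapping would work whether the checkpoint uses the old or the new naming convention; the actual fix merged in #724 may take a different approach.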

My environment from the pytorch data collector:

Collecting environment information...
PyTorch version: 2.1.0.dev20230524+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CachyOS (x86_64)
GCC version: (GCC) 13.1.1 20230504
Clang version: 15.0.7
CMake version: version 3.26.4
Libc version: glibc-2.37

Python version: 3.11.3 (main, May  4 2023, 16:07:26) [GCC 13.1.1 20230429] (64-bit runtime)
Python platform: Linux-6.3.2-zen1-1.1-zen-x86_64-with-glibc2.37
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2060 with Max-Q Design
Nvidia driver version: 530.41.03
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.8.0
/usr/lib/libcudnn_adv_infer.so.8.8.0
/usr/lib/libcudnn_adv_train.so.8.8.0
/usr/lib/libcudnn_cnn_infer.so.8.8.0
/usr/lib/libcudnn_cnn_train.so.8.8.0
/usr/lib/libcudnn_ops_infer.so.8.8.0
/usr/lib/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   44 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          16
On-line CPU(s) list:             0-15
Vendor ID:                       AuthenticAMD
Model name:                      AMD Ryzen 9 4900HS with Radeon Graphics
CPU family:                      23
Model:                           96
Thread(s) per core:              2
Core(s) per socket:              8
Socket(s):                       1
Stepping:                        1
Frequency boost:                 enabled
CPU(s) scaling MHz:              66%
CPU max MHz:                     3000.0000
CPU min MHz:                     1400.0000
BogoMIPS:                        5988.67
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization:                  AMD-V
L1d cache:                       256 KiB (8 instances)
L1i cache:                       256 KiB (8 instances)
L2 cache:                        4 MiB (8 instances)
L3 cache:                        8 MiB (2 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-15
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] clip-anytorch==2.5.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] pytorch-lightning==2.0.2
[pip3] pytorch-triton==2.1.0+7d1a95b046
[pip3] torch==2.1.0.dev20230524+cu121
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==1.0.0rc0
[pip3] torchsde==0.2.5
[pip3] torchvision==0.16.0.dev20230524+cu121
[conda] Could not collect

The full console log: log.txt

I tried an existing (otherwise usable) model and a fresh download of 1.5 with the included script.

Going to try the default demo model and see if it still fails...

The command I used for compilation:

python scripts/compile_alt.py --local-dir /home/alpha/Storage/AIModels/Stable-diffusion/stable-diffusion-2 --height 384 768 --width 384 768 --batch-size 1 2 --include-constants True

I cleared my HF cache just in case, but the default model (2.0) hits the same error.

...
│ /home/alpha/clone/AITemplate/examples/05_stable_diffusion/src/compile_lib/compile_vae_alt.py:58  │
│ in map_vae_params                                                                                │
│                                                                                                  │
│    55 │   │   elif name.endswith("attention.proj_q.weight"):                                     │
│    56 │   │   │   prefix = name[: -len("attention.proj_q.weight")]                               │
│    57 │   │   │   pt_name = prefix + "query.weight"                                              │
│ ❱  58 │   │   │   mapped_pt_params[ait_name] = pt_params[pt_name]                                │
│    59 │   │   elif name.endswith("attention.proj_q.bias"):                                       │
│    60 │   │   │   prefix = name[: -len("attention.proj_q.bias")]                                 │
│    61 │   │   │   pt_name = prefix + "query.bias"                                                │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'decoder.mid_block.attentions.0.query.weight'
hlky commented

Fix for LDM VAE mapping #724

Works now, thanks for the quick fix 👍

Oh, that's strange. The PR fixed this issue, and I just reinstalled AITemplate from a fresh clone of main after the commit was merged, but now I am getting the exact same error. Was it really merged?

Hmmm, I guess I will try this PR again and see.

EDIT: On second thought this could be an environment issue too, I need to check...

@AlphaAtlas #724 was merged in 89711d9 4 days ago, so it should be part of the main branch now. For example, I can see the updated code in the main branch. Please double check and let me know if it works. Thanks!

It is indeed an environment issue, thanks 👍