facebookincubator/AITemplate

Confused about the shape of the input Tensor

ningmenghongcha opened this issue · 7 comments

I compiled Stable Diffusion. The compiled binary model is very fast compared to PyTorch eager.
So I am going to compile a conditional 3D UNet model, UNet3DConditionModel, for text-to-video generation.
While checking the latent sample dimensions, I noticed that the Hugging Face noisy input tensor has the shape (batch, channel, height, width),
while the AIT input tensor is:

latent_model_input_ait = Tensor(
        [batch_size, height, width, 4], name="input0", is_input=True
    )

BHWC is different from BCHW.
My question is: why are the dims different? Do I have to follow your input shape?
Since I have to replace a lot of torch operators in UNet3DConditionModel, I will probably get the shapes mixed up.
Thank you for your attention!

For example, in PyTorch:

...
 self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 self.proj_in = nn.Linear(in_channels, inner_dim)
...

 def forward():
   ...
   # input is (batch size, channel, height, width)
   hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
   hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

   hidden_states = self.norm(hidden_states)
   hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

   hidden_states = self.proj_in(hidden_states)
   ...

While in AIT:

...
 self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 self.proj_in = nn.Linear(in_channels, inner_dim)
...

 def forward():
   ...
   # input is [batch_size*num_frames, height, width, 4]

   # what shapes do I need?
   hidden_states = self.norm(hidden_states)

   # what shapes do I need?
   hidden_states = self.proj_in(hidden_states)
   ...
hlky commented

AIT uses 'channels last' (BHWC rather than BCHW) due to the performance gains in operators like conv, pooling, and upsampling.

When converting models to AIT we need to account for the memory format in the graph itself, in the inputs, the outputs (when verifying, etc.), and the constants.

In common cases, inputs, outputs, and (some) constants need to be permuted, e.g. (0, 2, 3, 1) <-> (0, 3, 1, 2).
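
For instance, a minimal sketch of moving a PyTorch NCHW tensor into AIT's NHWC layout and back (plain PyTorch, nothing AIT-specific):

import torch

x_nchw = torch.randn(1, 320, 64, 64).half().cuda()   # PyTorch layout: (batch, channel, height, width)
x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()     # AIT input layout: (batch, height, width, channel)
x_back = x_nhwc.permute(0, 3, 1, 2).contiguous()     # back to NCHW, e.g. before verifying outputs
assert torch.equal(x_nchw, x_back)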

I would recommend creating a self-contained development script for an easier workflow; an example that I use is below.

Your specific case appears to be TransformerTemporalModel. Here channels is coming from in_channels=block_out_channels[0] in UNet3DConditionModel. If you have channels==4 at this point of UNet3DConditionModel it would seem there is an issue beforehand.

Using the development workflow, we print shapes at each step of the model for both PyTorch and AIT. For AIT, we might start with a 'basic' version i.e. no permutes etc. Keeping 'channels last' in mind and referring to PyTorch shapes we can then match them at each step, in some cases permute/reshape etc can be eliminated, in others arguments to an operator will need to be changed e.g. dim 1 to dim 3.
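
As a rough sketch of the shape-printing part of that workflow, a small helper for the AIT side (the helper name is mine; it reads the same _attrs that mark_output below uses, and on the PyTorch side you simply print tensor.shape at the matching step):

def print_ait_shape(name: str, tensor: Tensor) -> None:
    # AIT shapes are symbolic at graph-construction time; each dim carries its candidate values
    shape = [d._attrs["values"] for d in tensor._attrs["shape"]]
    print(f"AIT `{name}` shape {shape}")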

Once the AIT module a) compiles and b) shape matches (accounting for channels last) we verify numerical accuracy. AIT output tensors may need to be permuted etc before verification. If verification does not pass, try returning at an earlier point.
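
For a 4D image-like output, that check might look like the following sketch (whether the permute back to NCHW is needed depends on the output in question):

y_ait_nchw = y_ait.permute(0, 3, 1, 2).contiguous()   # NHWC -> NCHW to match the PyTorch reference
torch.testing.assert_close(y_ait_nchw, y_pt, rtol=1e-3, atol=1e-3)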

Following this process the AIT forward for your model section becomes:

num_frames = 1
batch_size, height, width, channel = ops.size()(hidden_states)
hidden_states = self.norm(hidden_states)
hidden_states = ops.reshape()(
    hidden_states, [batch_size * height * width, num_frames, channel]
)
hidden_states = self.proj_in(hidden_states)

However, please note there is actually a bug in the codegen of AIT nn.GroupNorm: it expects 4 dims, while the shape before the norm should include num_frames, which would in turn create numerical inaccuracy due to incorrectly detected GroupNorm parameters. Further testing for num_frames > 1 would be required to determine if this is an issue that needs fixing. In the case num_frames == 1, verification passes.
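
For reference, the shapes going into the norm in the two snippets above:

# PyTorch: hidden_states.permute(0, 2, 1, 3, 4) -> [batch_size, channel, num_frames, height, width]  (5 dims)
# AIT:     hidden_states                        -> [batch_size, height, width, channel]              (4 dims, num_frames == 1 folded away)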

from typing import cast, Iterable

import torch

from aitemplate.compiler import compile_model, Model, ops
from aitemplate.frontend import Tensor, nn
from aitemplate.testing import detect_target


def mark_output(tensor: Tensor, name: str):
    tensor._attrs["is_output"] = True
    tensor._attrs["name"] = name
    shape = [d._attrs["values"] for d in tensor._attrs["shape"]]
    print(f"AIT output `{name}` shape {shape}")
    return tensor


def inference(
    exe_module: Model,
    input_tensors: Iterable[torch.Tensor],
    benchmark: bool = False,
    benchmark_count: int = 25,
    benchmark_repeat: int = 2,
    graph_mode: bool = False,
) -> dict[str, torch.Tensor] | float:
    ys = {}
    for name, idx in exe_module.get_output_name_to_index_map().items():
        shape = exe_module.get_output_maximum_shape(idx)
        ys[name] = torch.empty(shape).cuda().half()
    if benchmark:
        t, _, _ = exe_module.benchmark_with_tensors(
            input_tensors, ys, count=benchmark_count, repeat=benchmark_repeat
        )
        return t
    else:
        exe_module.run_with_tensors(input_tensors, ys, graph_mode=graph_mode)
        return ys

class PTModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, hidden_states: torch.Tensor):
        return hidden_states


class AITModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, hidden_states: Tensor):
        return hidden_states

pt_model = PTModel().eval().half().cuda()
ait_model = AITModel()
ait_model.name_parameter_tensor()

input_shape_pt = [1, 320, 64, 64]
x_pt = torch.randn(input_shape_pt).half().cuda()

# permute etc when required
input_shape_ait = [1, 64, 64, 320]
x_ait = x_pt.clone().permute(0, 2, 3, 1).contiguous()
with torch.inference_mode():
    y_pt = pt_model.forward(x_pt)

X_ait = Tensor(input_shape_ait, name="X", is_input=True)
Y_ait = ait_model.forward(X_ait)
Y_ait = mark_output(Y_ait, "Y")
state_dict_pt = cast(dict[str, torch.Tensor], pt_model.state_dict())
state_dict_ait = {}
for key, value in state_dict_pt.items():
    key_ait = key.replace(".", "_")
    value = value.half().cuda().contiguous()
    state_dict_ait[key_ait] = value
target = detect_target()
work_dir = "./tmp"
model_name = "AITModel"
ait_module = compile_model(
    Y_ait,
    target,
    work_dir,
    model_name,
    constants=state_dict_ait,
)
y_ait = inference(ait_module, [x_ait])["Y"]
# permute etc if required

tol = 1e-3
torch.testing.assert_close(
    y_ait,
    y_pt,
    rtol=tol,
    atol=tol,
    msg=lambda msg: f"{msg}\n\npt ({y_pt.shape}):\n{y_pt}\n\nait ({y_ait.shape}):\n{y_ait}\n\n",
)

print("verified")

Thanks for your quick reply. Your explanations are great; I am following your workflow.
Since the first step of UNet3D is the time embedding, I am trying to match the shapes and verify numerical accuracy. Compiling seems to be okay; however, something goes wrong during inference with ait_module.

INFO:aitemplate.backend.build_cache_base:Build cache disabled
timesteps shape torch.Size([2])
t_emb shape torch.Size([2, 320])
emb shape torch.Size([2, 1280])
y_pt shape: torch.Size([2, 1280])
2024-06-13 15:14:38,234 INFO <aitemplate.testing.detect_target> Set target to CUDA
timesteps expand shape: [2]
t_emb shape: [2, 320]
emb shape: [2, 1280]
AIT output `Y` shape [[2], [1280]]
2024-06-13 15:14:38,236 INFO <aitemplate.compiler.compiler> Start to compile AIT model. test_dir='./tmp/AITModel'
2024-06-13 15:14:38,236 INFO <aitemplate.backend.target> Loading profile cache from: /root/.aitemplate/cuda.db
2024-06-13 15:14:38,236 INFO <aitemplate.backend.profiler_cache> table_name='cuda_gemm_3' exists in the db
2024-06-13 15:14:38,237 INFO <aitemplate.backend.profiler_cache> table_name='cuda_conv_3' exists in the db
2024-06-13 15:14:38,237 INFO <aitemplate.backend.profiler_cache> table_name='cuda_conv3d_3' exists in the db
2024-06-13 15:14:41,496 INFO <aitemplate.compiler.compiler> optimized graph elapsed time: 0:00:00.024311
2024-06-13 15:14:41,496 INFO <aitemplate.compiler.transform.refine_graph> reduced unique ops from 5 to 5
2024-06-13 15:14:41,496 INFO <aitemplate.compiler.transform.profile> Force profiler cache = False
2024-06-13 15:14:41,723 INFO <aitemplate.compiler.ops.gemm_universal.gemm_common> Load profiling result for gemm_rcr_bias_swish_11 from cache: ('cutlass_tensorop_f16_s16816gemm_f16_64x64_64x5_tn_align_8_8', 0, 1)
2024-06-13 15:14:41,955 INFO <aitemplate.compiler.ops.gemm_universal.gemm_common> Load profiling result for gemm_rcr_bias_12 from cache: ('cutlass_tensorop_f16_s16816gemm_f16_64x64_64x5_tn_align_8_8', 0, 1)
2024-06-13 15:14:41,955 INFO <aitemplate.compiler.transform.profile> generated 0 profilers elapsed time: 0:00:00.459273
2024-06-13 15:14:41,955 INFO <aitemplate.backend.builder> Using 128 CPU for building
2024-06-13 15:14:41,955 INFO <aitemplate.compiler.transform.profile> compiled profilers elapsed time: 0:00:00.000228
2024-06-13 15:14:41,956 INFO <aitemplate.backend.profiler_runner> Initialized profiler runner with devices: [0]
2024-06-13 15:14:41,956 INFO <aitemplate.compiler.ops.gemm_universal.gemm_common> Profile: gemm_rcr_bias_swish_11: M == 2 && N == 1280 && K == 320
2024-06-13 15:14:41,956 INFO <aitemplate.compiler.ops.gemm_universal.gemm_common> Profile: gemm_rcr_bias_12: M == 2 && N == 1280 && K == 1280
2024-06-13 15:14:41,957 INFO <aitemplate.compiler.transform.profile> ran 2 profilers elapsed time: 0:00:00.000962
2024-06-13 15:14:41,957 INFO <aitemplate.compiler.transform.memory_planning> Workspace shared_size=0 unique_size=0
2024-06-13 15:14:41,957 INFO <aitemplate.compiler.transform.memory_planning> max_blob=320 constant_offset=0
2024-06-13 15:14:41,958 INFO <aitemplate.backend.codegen> generated 1 function srcs
2024-06-13 15:14:41,960 INFO <aitemplate.compiler.compiler> folded constants elapsed time: 0:00:00.003156
2024-06-13 15:14:41,960 INFO <aitemplate.compiler.transform.memory_planning> Workspace shared_size=0 unique_size=0
2024-06-13 15:14:41,960 INFO <aitemplate.compiler.transform.memory_planning> max_blob=10240 constant_offset=4101440
2024-06-13 15:14:41,967 INFO <aitemplate.backend.codegen> generated 4 function srcs
2024-06-13 15:14:41,984 INFO <aitemplate.backend.codegen> generated 8 library srcs
2024-06-13 15:14:41,984 INFO <aitemplate.backend.builder> Using 128 CPU for building
2024-06-13 15:15:01,982 INFO <aitemplate.compiler.compiler> compiled the final .so file elapsed time: 0:00:19.997474
[15:15:01] model_container.cu:69: Device Runtime Version: 11080; Driver Version: 12040
[15:15:01] model_container.cu:83: Hardware accelerator device properties: 
  Device: 
     ASCII string identifying device: NVIDIA A800 80GB PCIe
     Major compute capability: 8
     Minor compute capability: 0
     UUID: GPU-cc0f40aa-67d8-2fc4-05e5-ccb90e96019d
     Unique identifier for a group of devices on the same multi-GPU board: 0
     PCI bus ID of the device: 53
     PCI device ID of the device: 0
     PCI domain ID of the device: 0
  Memory limits: 
     Constant memory available on device in bytes: 65536
     Global memory available on device in bytes: 84974239744
     Size of L2 cache in bytes: 41943040
     Shared memory available per block in bytes: 49152
     Shared memory available per multiprocessor in bytes: 167936
[15:15:01] model_container.cu:87: Init AITemplate Runtime with 1 concurrency
[15:15:01] model_interface.cu:221: Error: Constant arange was not set! Set the value with set_constant.
Traceback (most recent call last):
  File "timesteps.py", line 337, in <module>
    y_ait = inference(ait_module, [x0_ait,x1_ait,x2_ait])["Y"]
  File "timesteps.py", line 38, in inference
    exe_module.run_with_tensors(input_tensors, ys, graph_mode=graph_mode)
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 597, in run_with_tensors
    outputs_ait = self.run(
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 495, in run
    return self._run_impl(
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 434, in _run_impl
    self.DLL.AITemplateModelContainerRun(
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 196, in _wrapped_func
    raise RuntimeError(f"Error in function: {method.__name__}")
RuntimeError: Error in function: AITemplateModelContainerRun

My Python code is below:

from aitemplate.compiler import compile_model, Model, ops
from aitemplate.frontend import Tensor, nn
from aitemplate.testing import detect_target

from typing import Any, Dict, List, Optional, Tuple, Union, cast, Iterable
import torch
import math

def get_shape(x):
    shape = [it.value() for it in x._attrs["shape"]]
    return shape
    
def mark_output(tensor: Tensor, name: str):
    tensor._attrs["is_output"] = True
    tensor._attrs["name"] = name
    shape = [d._attrs["values"] for d in tensor._attrs["shape"]]
    print(f"AIT output `{name}` shape {shape}")
    return tensor

def inference(
    exe_module: Model,
    input_tensors: Iterable[torch.Tensor],
    benchmark: bool = False,
    benchmark_count: int = 25,
    benchmark_repeat: int = 2,
    graph_mode: bool = False,
):
    ys = {}
    for name, idx in exe_module.get_output_name_to_index_map().items():
        shape = exe_module.get_output_maximum_shape(idx)
        ys[name] = torch.empty(shape).cuda().half()
    if benchmark:
        t, _, _ = exe_module.benchmark_with_tensors(
            input_tensors, ys, count=benchmark_count, repeat=benchmark_repeat
        )
        return t
    else:
        exe_module.run_with_tensors(input_tensors, ys, graph_mode=graph_mode)
        return ys

class TimestepEmbeddingTorch(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        # self.act = get_activation(act_fn)
        self.act = torch.nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

class TimestepEmbeddingAIT(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, specialization="swish")
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample):
        sample = self.linear_1(sample)
        sample = self.linear_2(sample)
        return sample

class TimestepsTorch(torch.nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = False
        self.downscale_freq_shift = 1
    
    def get_timestep_embedding(
        self,
        timesteps: torch.Tensor,
        embedding_dim: int,
        flip_sin_to_cos: bool = False,
        downscale_freq_shift: float = 1,
        scale: float = 1,
        max_period: int = 10000,
    ):
        """
        This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

        :param timesteps: a 1-D Tensor of N indices, one per batch element.
                        These may be fractional.
        :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
        embeddings. :return: an [N x dim] Tensor of positional embeddings.
        """
        assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

        half_dim = embedding_dim // 2
        exponent = -math.log(max_period) * torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
        )
        exponent = exponent / (half_dim - downscale_freq_shift)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        # scale embeddings
        emb = scale * emb

        # concat sine and cosine embeddings
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

        # flip sine and cosine embeddings
        if flip_sin_to_cos:
            emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

        # zero pad
        if embedding_dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        return emb

    def forward(self, timesteps):
        t_emb = self.get_timestep_embedding(
            timesteps,
            self.num_channels,
            # flip_sin_to_cos=self.flip_sin_to_cos,
            # downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb

class TimestepsAIT(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
    def get_timestep_embedding(
        self,
        timesteps: Tensor,
        embedding_dim: int,
        flip_sin_to_cos: bool = False,
        downscale_freq_shift: float = 1,
        scale: float = 1,
        max_period: int = 10000,
    ):
        """
        This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

        :param timesteps: a 1-D Tensor of N indices, one per batch element.
                        These may be fractional.
        :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
        embeddings. :return: an [N x dim] Tensor of positional embeddings.
        """
        assert timesteps._rank() == 1, "Timesteps should be a 1d-array"

        half_dim = embedding_dim // 2

        exponent = (-math.log(max_period)) * Tensor(
            shape=[half_dim], dtype="float16", name="arange"
        )

        exponent = exponent * (1.0 / (half_dim - downscale_freq_shift))

        emb = ops.exp(exponent)
        emb = ops.reshape()(timesteps, [-1, 1]) * ops.reshape()(emb, [1, -1])

        # scale embeddings
        emb = scale * emb

        # concat sine and cosine embeddings
        if flip_sin_to_cos:
            emb = ops.concatenate()(
                [ops.cos(emb), ops.sin(emb)],
                dim=-1,
            )
        else:
            emb = ops.concatenate()(
                [ops.sin(emb), ops.cos(emb)],
                dim=-1,
            )
        return emb

    def forward(self, timesteps):
        t_emb = self.get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb

class TimeEmbTorch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # time
        block_out_channels = (320, 640, 1280, 1280)
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = TimestepsTorch(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbeddingTorch(
            timestep_input_dim,
            time_embed_dim,
            act_fn="silu",
            # cond_proj_dim=time_cond_proj_dim,
        )
    
    def forward(self,sample,timestep,encoder_hidden_states):
        timesteps = timestep
        num_frames = sample.shape[2]
        timesteps = timesteps.expand(sample.shape[0])
        print("timesteps shape",timesteps.shape)
        t_emb = self.time_proj(timesteps)
        print("t_emb shape",t_emb.shape)
        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.half()
        emb = self.time_embedding(t_emb)
        print("emb shape",emb.shape)
        # emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        # encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
        return emb

class TimeEmbAIT(nn.Module):
    def __init__(self):
        super().__init__()
        # time
        block_out_channels = (320, 640, 1280, 1280)
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = TimestepsAIT(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbeddingAIT(
            timestep_input_dim,
            time_embed_dim,
            act_fn="silu",
            # cond_proj_dim=time_cond_proj_dim,
        )
    
    def forward(self,sample,timestep,encoder_hidden_states):
        timesteps = timestep
        # print("timesteps shape:",get_shape(timesteps))
        timesteps = ops.expand()(timesteps,[2])    
        print("timesteps expand shape:",get_shape(timesteps))
        t_emb = self.time_proj(timesteps)
        print("t_emb shape:",get_shape(t_emb))
        emb = self.time_embedding(t_emb)
        print("emb shape:",get_shape(emb))
        return emb


if __name__ == "__main__":

    # 1.pt model inference
    pt_model = TimeEmbTorch().eval().half().cuda()
    sample_pt_shape = [2,4,16,32,32]  # b,c,frames,h,w
    timestep_pt_shape = [1]
    encoder_hidden_states_pt_shape = [2,77,1024]
    sample_pt = torch.randn(sample_pt_shape).half().cuda()
    timestep_pt = torch.randn(timestep_pt_shape).half().cuda()
    encoder_hidden_states_pt = torch.randn(encoder_hidden_states_pt_shape).half().cuda()
    with torch.inference_mode():
        y_pt = pt_model.forward(sample_pt,timestep_pt,encoder_hidden_states_pt)
        print("y_pt shape:",y_pt.shape)

    # 2.ait model inference
    ait_model = TimeEmbAIT()
    ait_model.name_parameter_tensor()
    # forward need batch, num_channels, num_frames, height, width
    latent_model_input_ait = Tensor([2*16, 32, 32, 4], name="input0", is_input=True)
    timesteps_ait = Tensor([1], name="input1", is_input=True)
    text_embeddings_pt_ait = Tensor([2, 77, 1024], name="input2", is_input=True) 
    Y_ait = ait_model.forward(latent_model_input_ait,timesteps_ait,text_embeddings_pt_ait)
    Y_ait = mark_output(Y_ait, "Y")

    # state_dict_pt = cast(dict[str, torch.Tensor], pt_model.state_dict()) python3.8 not work
    state_dict_pt = cast(Dict[str, torch.Tensor], pt_model.state_dict())
    state_dict_ait = {}
    for key, value in state_dict_pt.items():
        key_ait = key.replace(".", "_")
        value = value.half().cuda().contiguous()
        state_dict_ait[key_ait] = value

    # print("state_dict_ait:",state_dict_ait)
    target = detect_target()
    work_dir = "./tmp"
    model_name = "AITModel"
    ait_module = compile_model(
        Y_ait,
        target,
        work_dir,
        model_name,
        constants=state_dict_ait,
    )
    
    x0_ait = sample_pt.clone().view(2, 16, 32, 32, 4).permute(0, 1, 3, 4, 2).reshape(2*16, 32, 32, 4).contiguous()
    x1_ait = timestep_pt.clone().contiguous()
    x2_ait = encoder_hidden_states_pt.clone().contiguous()
    # print("x0_ait shape:",x0_ait.shape)
    # print("x1_ait shape:",x1_ait.shape)
    # print("x2_ait shape:",x2_ait.shape)
    y_ait = inference(ait_module, [x0_ait,x1_ait,x2_ait])["Y"]
    tol = 1e-3
    torch.testing.assert_close(
        y_ait,
        y_pt,
        rtol=tol,
        atol=tol,
        msg=lambda msg: f"{msg}\n\npt ({y_pt.shape}):\n{y_pt}\n\nait ({y_ait.shape}):\n{y_ait}\n\n",
    )
  
    print("verified")
hlky commented

Errors from AIT itself are printed above the traceback.

[15:15:01] model_interface.cu:221: Error: Constant arange was not set! Set the value with set_constant.
Traceback (most recent call last):

Arange kernel is not implemented, so the tensor is provided as a precomputed constant.

Tensor(
    shape=[half_dim], dtype="float16", name="arange"
)

See UNet2DCondition mapping
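
In other words, the arange values are computed on the PyTorch side and passed as a constant whose name matches the Tensor declared in the model. A sketch of what that might look like for your TimestepsAIT (half_dim = 320 // 2 = 160; the dtype has to match the declared "float16"):

half_dim = 320 // 2  # block_out_channels[0] // 2
state_dict_ait["arange"] = torch.arange(0, half_dim, dtype=torch.float32).to(torch.float16).cuda().contiguous()

ait_module = compile_model(Y_ait, target, work_dir, model_name, constants=state_dict_ait)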

Thanks a lot. I used the mapping. However, there is another error that confuses me.

[15:33:57] model_container.cu:87: Init AITemplate Runtime with 1 concurrency
<aitemplate.compiler.model.Model object at 0x7f07b9f33850>
[15:33:57] model_interface.cu:221: Error: Got wrong number of inputs; expected 1, got 3
Traceback (most recent call last):
  File "test.py", line 102, in <module>
    y_ait = inference(ait_module, [x_ait,x_ait,x_ait])["Y"]
  File "test.py", line 46, in inference
    exe_module.run_with_tensors(input_tensors, ys, graph_mode=graph_mode)
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 597, in run_with_tensors
    outputs_ait = self.run(
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 495, in run
    return self._run_impl(
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 434, in _run_impl
    self.DLL.AITemplateModelContainerRun(
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/model.py", line 196, in _wrapped_func
    raise RuntimeError(f"Error in function: {method.__name__}")
RuntimeError: Error in function: AITemplateModelContainerRun

I want three input Tensors as in your demo, as follows:

class AITModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, hidden_states,x1,x2):
        return hidden_states

ait_model = AITModel()
ait_model.name_parameter_tensor()

X_ait0 = Tensor(input_shape_ait, name="x0", is_input=True)
X_ait1 = Tensor(input_shape_ait, name="x1", is_input=True)
X_ait2 = Tensor(input_shape_ait, name="x2", is_input=True)

Y_ait = ait_model.forward(X_ait0,X_ait1,X_ait2)

ait_module = compile_model(
    Y_ait,
    target,
    work_dir,
    model_name,
    constants=state_dict_ait,
)

x_ait = x_pt.clone().permute(0, 2, 3, 1).contiguous()
y_ait = inference(ait_module, [x_ait,x_ait,x_ait])["Y"]

It's okay in PyTorch.

with torch.inference_mode():
    y_pt = pt_model.forward(x_pt,x_pt,x_pt)

Why is it saying it expected 1 input when I am giving three, and compiling produced no errors?

hlky commented

AIT removes unused inputs
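
Since the toy AITModel above only returns hidden_states, x1 and x2 never reach the output, so they are pruned from the compiled graph. A sketch of how to see (and work with) what survived:

# only inputs the output actually depends on are kept; check what the compiled module expects
print(ait_module.get_input_name_to_index_map())   # e.g. {"x0": 0}

# either feed just those inputs, or make the toy graph consume all three so nothing is pruned,
# e.g. return hidden_states + x1 + x2 in AITModel.forward
y_ait = inference(ait_module, [x_ait])["Y"]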

Thanks, it works.
I will try to verify numerical accuracy step by step.