facebookincubator/AITemplate

Low performance from unnecessary permutations

jonpryai opened this issue · 18 comments

I'm using fx2ait to load an ONNX graph. After optimization, the results are not good: FX2AIT is slower than PyTorch eager.

```
BS: 1, PT Eager time per iter: 0.01654841552734375ms, PT Eager QPS: 60.43, FX2AIT time per iter: 0.024108586425781252ms, FX2AIT Eager QPS: 41.48, Speedup: 0.69
```
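A quick arithmetic check on the numbers above. The QPS figures equal 1 / (time per iter), so the per-iteration times printed with an `ms` label appear to actually be in seconds (about 16.5 ms for PyTorch eager and 24.1 ms for FX2AIT). This only re-derives the logged values and assumes QPS means iterations per second at batch size 1.

```python
# Values copied from the benchmark log above.
pt_time_per_iter = 0.01654841552734375
ait_time_per_iter = 0.024108586425781252

# If the times are in seconds, QPS at batch size 1 is simply 1 / time.
pt_qps = 1.0 / pt_time_per_iter
ait_qps = 1.0 / ait_time_per_iter
speedup = pt_time_per_iter / ait_time_per_iter

print(f"PT QPS: {pt_qps:.2f}, AIT QPS: {ait_qps:.2f}, speedup: {speedup:.2f}")
# prints: PT QPS: 60.43, AIT QPS: 41.48, speedup: 0.69
print(f"PT: {pt_time_per_iter * 1e3:.1f} ms, AIT: {ait_time_per_iter * 1e3:.1f} ms")
# prints: PT: 16.5 ms, AIT: 24.1 ms
```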

And that's before even comparing against TensorRT. I profiled the optimized graph and found that PermuteKernel accounts for 61.9% of GPU kernel time:

```
61.9 11,952,884,520 64,800 184,458.1 123,999.0 18,112 1,017,919 216,853.6 void ::PermuteKernel<(unsigned long)4, (unsigned long)2, int>(::PermuteKernelPara…
```

Analyzing this in nsys, I see that the graph consistently does:

permute -> elementwise addition -> permute

These permutations are redundant: an elementwise operator produces the same values regardless of memory layout, so permuting to NCHW and straight back to NHWC around it is wasted work.
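A small NumPy check of this claim. NumPy stands in for the GPU tensors and the 1x8x8x4 shape is arbitrary. Applying elementwise ops between a toNCHW permute and its inverse toNHWC permute gives exactly the same result as applying them directly to the NHWC tensors, provided every operand is in the same layout.

```python
import numpy as np

NHWC_TO_NCHW = (0, 3, 1, 2)  # permutation applied by a "toNCHW" permute
NCHW_TO_NHWC = (0, 2, 3, 1)  # its inverse, a "toNHWC" permute


def relu(t):
    return np.maximum(t, 0)


rng = np.random.default_rng(0)
a = rng.standard_normal((1, 8, 8, 4)).astype(np.float16)  # NHWC conv outputs
b = rng.standard_normal((1, 8, 8, 4)).astype(np.float16)

# What the lowered graph does: permute -> add + relu -> permute back.
roundtrip = np.transpose(
    relu(np.transpose(a, NHWC_TO_NCHW) + np.transpose(b, NHWC_TO_NCHW)),
    NCHW_TO_NHWC,
)
# The same computation done directly in NHWC, with no permutes at all.
direct = relu(a + b)

assert roundtrip.shape == direct.shape == (1, 8, 8, 4)
assert np.array_equal(roundtrip, direct)
```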

How can this be fixed?

Hi @jonpryai, thanks for flagging this. It does seem like at least one of the permutes could be redundant. But without a minimal repro, it's hard to determine whether they should be removed and whether we need a pass to handle this case.

Do you mind sharing details on how to reproduce this? Thanks!

I use this script to compile:

```python
import onnx
import torch
from onnx2torch import convert
from fx2ait.example.benchmark_utils import benchmark_function

batch_size = 1


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Convert the ONNX graph into a PyTorch module
        onnx_model = onnx.load("test.onnx")
        self.mod = convert(onnx_model)

    def forward(self, x):
        return self.mod(x)


model = TestModule().cuda().half()
inputs = [torch.randn(batch_size, 16, 224, 224).half().cuda()]
# Lower with fx2ait and compare against PyTorch eager
benchmark_function(
    TestModule.__name__,  # benchmark name (self is undefined at module level)
    100,
    model,
    inputs,
)
```

Profiling the script above is difficult because the trace is dominated by compilation and AIT's kernel profiling. Instead, I locate the generated test.so in /tmp, copy it to the current directory, and then run:

```python
import os

import torch
from aitemplate.compiler import Model


def benchmark(model_name, batch_size, mod=None, graph_mode=True):
    # Load the compiled model (test.so copied from /tmp into the current dir)
    if mod is None:
        mod = Model(os.path.join("./", "test.so"))

    # Prepare input/output tensors (fp16, contiguous)
    x_input = torch.randn([batch_size, 16, 224, 224]).cuda().half()
    x_input = x_input.contiguous()
    y_output = torch.zeros([batch_size, 64, 56, 56]).cuda().half()
    y_output = y_output.contiguous()

    # Warm up
    t, _, __ = mod.benchmark_with_tensors(
        [x_input],
        [y_output],
        count=100,
        repeat=4,
        graph_mode=graph_mode,
    )
    # Benchmark
    t, _, __ = mod.benchmark_with_tensors(
        [x_input],
        [y_output],
        count=100,
        repeat=4,
        graph_mode=graph_mode,
    )
    print(f"batch_size: {batch_size}, latency: {t}")
    dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1")
    dev_flag = dev_flag.replace(",", "_")
    with open(f"resnet50_ait_benchmark_dev_{dev_flag}.txt", "a") as f:
        f.write(f"batch_size: {batch_size}, latency: {t}\n")


if __name__ == "__main__":
    benchmark("", 1)
```

Example ONNX section: [test.zip](https://github.com/facebookincubator/AITemplate/files/12752757/test.zip)

For security reasons, I'm unable to download external files. I hope you understand.

It would be easier to reproduce if you could share the model graph that AIT dumps automatically via dump_graph_debug_str_to_file. To get it, set the environment variable LOGLEVEL=DEBUG when you compile the model; the files will appear in your workdir.

Once you do that, could you share the contents of memory_planning_pseudo_code.txt?

```
(Tensor(name=permute_0_0, shape=[1, 224, 224, 16]))
= permute()(
Tensor(name=x, shape=[1, 16, 224, 224]))

# conv2d_bias_1
(Tensor(name=conv2d_bias_1_0, shape=[1, 112, 112, 32]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=permute_0_0, shape=[1, 224, 224, 16]), Tensor(name=mod_level1_level1_0_Conv_weight, shape=[32, 3, 3, 16], data=(9216 bytes)), Tensor(name=mod_level1_level1_0_Conv_bias, shape=[32], data=(64 bytes)))

# permute_2
(Tensor(name=permute_2_0, shape=[1, 32, 112, 112]))
= permute()(
Tensor(name=conv2d_bias_1_0, shape=[1, 112, 112, 32]))

# fused_elementwise_19
(Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
Tensor(name=permute_2_0, shape=[1, 32, 112, 112]))

# permute_4
(Tensor(name=permute_4_0, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))

# permute_4
(Tensor(name=permute_5_0, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))

# conv2d_bias_6
(Tensor(name=conv2d_bias_6_0, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=permute_5_0, shape=[1, 112, 112, 32]), Tensor(name=mod_level2_tree1_conv1_Conv_weight, shape=[64, 3, 3, 32], data=(36864 bytes)), Tensor(name=mod_level2_tree1_conv1_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_7_0, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=conv2d_bias_6_0, shape=[1, 56, 56, 64]))

# fused_elementwise_20
(Tensor(name=elementwise_8_0, shape=[1, 64, 56, 56]))
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
Tensor(name=permute_7_0, shape=[1, 64, 56, 56]))

# permute_9
(Tensor(name=permute_9_0, shape=[1, 56, 56, 64]))
= permute()(
Tensor(name=elementwise_8_0, shape=[1, 64, 56, 56]))

# conv2d_bias_10
(Tensor(name=conv2d_bias_10_0, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=1)(
Tensor(name=permute_9_0, shape=[1, 56, 56, 64]), Tensor(name=mod_level2_tree1_conv2_Conv_weight, shape=[64, 3, 3, 64], data=(73728 bytes)), Tensor(name=mod_level2_tree1_conv2_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_11_0, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=conv2d_bias_10_0, shape=[1, 56, 56, 64]))

# max_pool2d_12
(Tensor(name=max_pool2d_12_0, shape=[1, 56, 56, 32]))
= max_pool2d(stride=2, pad=0, kernel_size=2, reduce_func=max)(
Tensor(name=permute_4_0, shape=[1, 112, 112, 32]))

# conv2d_bias_15
(Tensor(name=conv2d_bias_15_0, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=0, stride=1)(
Tensor(name=max_pool2d_12_0, shape=[1, 56, 56, 32]), Tensor(name=mod_level2_project_project_0_Conv_weight, shape=[64, 1, 1, 32], data=(4096 bytes)), Tensor(name=mod_level2_project_project_0_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_16_0, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=conv2d_bias_15_0, shape=[1, 56, 56, 64]))

# fused_elementwise_21
(Tensor(name=output_0, shape=[1, 64, 56, 56]))
= fused_elementwise(func=[<FuncEnum.ADD: 1>, <FuncEnum.RELU: 18>])(
Tensor(name=permute_11_0, shape=[1, 64, 56, 56]), Tensor(name=permute_16_0, shape=[1, 64, 56, 56]))
```

This image of the network might be helpful.

[Image: network graph, "Screenshot from 2023-10-05 09-44-11"]

It does seem like either permute_2 or permute_4 can be removed here. In my opinion, it'll be easier to remove permute_2.

Sorry for the delay. Here is what I believe we need:

  1. Find the conditions for removing permute_2.
    • We can make the conditions specific for your graph (i.e. only when the middle op is an elementwise-relu/gelu/etc.).
  2. Remove the first permute.
  3. Take the first permute's input (conv2d_bias_1_0) and make it the new input for the middle op (fused_elementwise_19).
  4. Confirm the shapes are correct for middle op and the remaining permute.
  5. Write a test case and confirm its accuracy.
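The steps above can be sketched on a toy graph. Everything here is hypothetical: the `Node` class, the op names and the `ELEMENTWISE` set are illustrative assumptions, not AIT's Operator/Tensor API or its pass infrastructure. Because the trailing permute has to be the exact inverse of the first one, removing the first also turns the trailing ones into no-ops, so the sketch drops both and rewires their consumers. Shape checks (step 4) and the accuracy test (step 5) are left out.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


# Hypothetical toy IR for illustration only; NOT AITemplate's Operator/Tensor API.
@dataclass(eq=False)  # eq=False keeps identity-based ==/hash, so nodes work in sets
class Node:
    name: str
    op: str  # e.g. "input", "conv2d", "permute", "relu"
    inputs: List["Node"] = field(default_factory=list)
    perm: Optional[Tuple[int, ...]] = None  # set only for "permute" nodes


# Unary ops whose result does not depend on memory layout (assumed set).
ELEMENTWISE = {"relu", "gelu", "sigmoid"}


def inverse(perm: Tuple[int, ...]) -> Tuple[int, ...]:
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return tuple(inv)


def consumers(graph: List[Node], node: Node) -> List[Node]:
    return [n for n in graph if node in n.inputs]


def remove_permutes_around_elementwise(graph: List[Node]) -> List[Node]:
    """Rewrite permute(p) -> elementwise -> permute(inverse(p)) into the bare elementwise op.

    `graph` is a topologically sorted node list; returns a new list without dead permutes.
    """
    dead = set()
    for ew in graph:
        # Step 1: match a unary elementwise op fed by a permute ...
        if ew.op not in ELEMENTWISE or len(ew.inputs) != 1 or ew.inputs[0].op != "permute":
            continue
        first = ew.inputs[0]
        users = consumers(graph, ew)
        # ... whose every consumer is a permute that exactly undoes the first one.
        if not users or any(u.op != "permute" or u.perm != inverse(first.perm) for u in users):
            continue
        # Steps 2-3: the elementwise op reads the un-permuted tensor directly.
        ew.inputs = [first.inputs[0]]
        # The trailing permutes are now no-ops: point their consumers at ew instead.
        for second in users:
            for n in consumers(graph, second):
                n.inputs = [ew if i is second else i for i in n.inputs]
            dead.add(second)
        if not consumers(graph, first):
            dead.add(first)
    return [n for n in graph if n not in dead]


# Usage on the first block of the dumped graph (names taken from the dump).
x = Node("x_nhwc", "input")
conv = Node("conv2d_bias_1", "conv2d", [x])
p2 = Node("permute_2", "permute", [conv], perm=(0, 3, 1, 2))
relu = Node("elementwise_3", "relu", [p2])
p4 = Node("permute_4", "permute", [relu], perm=(0, 2, 3, 1))
p5 = Node("permute_5", "permute", [relu], perm=(0, 2, 3, 1))
pool = Node("max_pool2d_12", "max_pool2d", [p4])
conv6 = Node("conv2d_bias_6", "conv2d", [p5])
g = remove_permutes_around_elementwise([x, conv, p2, relu, p4, p5, pool, conv6])
print([n.name for n in g])
# prints: ['x_nhwc', 'conv2d_bias_1', 'elementwise_3', 'max_pool2d_12', 'conv2d_bias_6']
```

A real pass would also need to handle graph outputs (a trailing permute that is itself a model output cannot simply disappear) and multi-input elementwise ops such as the ADD in fused_elementwise_21.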

Here are some pointers:

Let me know if you have any questions.

I am not very familiar with the code, so I could be wrong. But my first impression is that while the optimizer can consider different layouts (NHWC and NCHW) for conv2d, for some reason it is tied to NCHW for the elementwise ops, and it may not take the cost of the permutations into account.

I think both permute_2 and permute_4 can be removed. There are also two copies of permute_4 (outputs permute_4_0 and permute_5_0) that yield exactly the same tensor. What is happening here is:
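The duplicated permute is a textbook common-subexpression-elimination case: two permutes with the same input and the same permutation produce the same tensor, so the second can be dropped and its uses redirected to the first. Below is a minimal sketch on a hypothetical flat op list (not AIT's data structures; the perm tuple for permute_4/permute_5 is assumed from their shapes).

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical flat op list for illustration (not AIT's data structures):
# (output_name, op_name, input_names, perm), where perm is None for non-permute ops.
Op = Tuple[str, str, Tuple[str, ...], Optional[Tuple[int, ...]]]


def dedup_permutes(ops: List[Op]) -> List[Op]:
    """Drop a permute that repeats an earlier permute (same input, same perm) and
    point later uses of its output at the first copy's output instead."""
    first_copy: Dict[Tuple[str, Optional[Tuple[int, ...]]], str] = {}
    rename: Dict[str, str] = {}
    result: List[Op] = []
    for name, op, inputs, perm in ops:
        inputs = tuple(rename.get(i, i) for i in inputs)
        if op == "permute":
            key = (inputs[0], perm)
            if key in first_copy:
                rename[name] = first_copy[key]
                continue
            first_copy[key] = name
        result.append((name, op, inputs, perm))
    return result


TO_NHWC = (0, 2, 3, 1)
ops = [
    ("elementwise_3_0", "relu", ("permute_2_0",), None),
    ("permute_4_0", "permute", ("elementwise_3_0",), TO_NHWC),
    ("permute_5_0", "permute", ("elementwise_3_0",), TO_NHWC),
    ("conv2d_bias_6_0", "conv2d_bias", ("permute_5_0",), None),
    ("max_pool2d_12_0", "max_pool2d", ("permute_4_0",), None),
]
for name, _, inputs, _ in dedup_permutes(ops):
    print(name, "<-", inputs)
# conv2d_bias_6_0 now reads permute_4_0, and permute_5_0 is gone.
```

As written, the sketch would also drop a duplicate that happens to be a graph output; a real pass would have to keep or alias those.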

```
conv2d(NHWC) -> toNCHW -> elementWise -> toNHWC
                                      -> toNHWC
```

which is the same thing as

```
conv2d(NHWC) -> elementWise
```

Ah, I see; both permutes can definitely be removed in that case. I'm not sure which pass introduces them in the first place, though.

Do you still have the dumped graphs in your directory? We can see which pass adds the permutes by looking at the {passname}_pseudo_code.txt.
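One quick way to do that comparison is to count the `= permute()(` calls in every dump. This is a small helper sketch; it assumes the dumps are named `{passname}_pseudo_code.txt` as described above and that op calls appear on their own line as in the pasted graphs. The workdir path in the usage line is hypothetical.

```python
import re
from pathlib import Path

# Op-call lines in the dumps look like "= permute()(".
PERMUTE_CALL = re.compile(r"^=\s*permute\(")


def count_permutes(workdir):
    """Return {dump_name: number of permute calls} for each *_pseudo_code.txt in workdir."""
    counts = {}
    for path in sorted(Path(workdir).glob("*_pseudo_code.txt")):
        dump_name = path.name[: -len("_pseudo_code.txt")]
        lines = path.read_text().splitlines()
        counts[dump_name] = sum(1 for line in lines if PERMUTE_CALL.match(line.strip()))
    return counts


if __name__ == "__main__":
    # Hypothetical path: point this at the workdir used when compiling the model.
    for dump_name, n in count_permutes("./tmp/test").items():
        print(f"{dump_name}: {n} permute(s)")
```

The results are sorted by file name, which is alphabetical rather than pass order, so look up the pass sequence separately when reading the counts.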

They are present in every dump except toposort_pseudo_code.txt. So is the bind_constants pass causing it?

Actually, that's not true. They are in the toposort dump too; the tensors just haven't been named yet:

```
(Tensor(name=None, shape=[1, 224, 224, 16]))
= permute()(
Tensor(name=x, shape=[1, 16, 224, 224]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=None, shape=[1, 224, 224, 16]), Tensor(name=mod_level1_level1_0_Conv_weight, shape=[32, 3, 3, 16], data=(9216 bytes)), Tensor(name=mod_level1_level1_0_Conv_bias, shape=[32], data=(64 bytes)))

# None
(Tensor(name=None, shape=[1, 32, 112, 112]))
= permute()(
Tensor(name=None, shape=[1, 112, 112, 32]))

# None
(Tensor(name=None, shape=[1, 32, 112, 112]))
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=None, shape=[1, 112, 112, 32]), Tensor(name=mod_level2_tree1_conv1_Conv_weight, shape=[64, 3, 3, 32], data=(36864 bytes)), Tensor(name=mod_level2_tree1_conv1_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= permute()(
Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=1)(
Tensor(name=None, shape=[1, 56, 56, 64]), Tensor(name=mod_level2_tree1_conv2_Conv_weight, shape=[64, 3, 3, 64], data=(73728 bytes)), Tensor(name=mod_level2_tree1_conv2_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 56, 56, 32]))
= max_pool2d(stride=2, pad=0, kernel_size=2, reduce_func=max)(
Tensor(name=None, shape=[1, 112, 112, 32]))

# None
(Tensor(name=None, shape=[1, 32, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 32]))

# None
(Tensor(name=None, shape=[1, 56, 56, 32]))
= permute()(
Tensor(name=None, shape=[1, 32, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=0, stride=1)(
Tensor(name=None, shape=[1, 56, 56, 32]), Tensor(name=mod_level2_project_project_0_Conv_weight, shape=[64, 1, 1, 32], data=(4096 bytes)), Tensor(name=mod_level2_project_project_0_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= elementwise(func=FuncEnum.ADD)(
Tensor(name=None, shape=[1, 64, 56, 56]), Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=output_0, shape=[1, 64, 56, 56]))
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 64, 56, 56]))
```

Is it possible these nodes are being inserted by fx2ait?

It could be fx2ait, but it may also be onnx2torch.

I'm curious whether replicating the model in PyTorch and then lowering it with fx2ait gives the same graph. If not, I'd assume the permutes come from onnx2torch.

[Image: model gv]

The permutes do not appear in the converted PyTorch model; they only show up in the AITModel after the trace is performed.

You're right, the permutes are being added in fx2ait. The result from each conv2d is being permuted via ait_nhwc2nchw (here).

AIT does that because PyTorch takes channel-first (NCHW) tensors for conv, maxpool, etc., whereas AIT takes channel-last (NHWC) tensors.
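For reference, here are the two layout conversions written as NumPy axis permutations. The constant names are illustrative; this does not reproduce fx2ait's ait_nhwc2nchw implementation, only the axis order such a conversion has to use. It also shows why a toNCHW immediately followed by a toNHWC is removable: composed, they are the identity permutation.

```python
import numpy as np

# Axis permutations for the two layout conversions (names are illustrative).
NCHW_TO_NHWC = (0, 2, 3, 1)  # PyTorch channel-first -> AIT channel-last
NHWC_TO_NCHW = (0, 3, 1, 2)  # AIT channel-last -> PyTorch channel-first


def compose(first, second):
    """Single permutation equal to np.transpose by `first`, then by `second`."""
    return tuple(first[i] for i in second)


x = np.zeros((1, 16, 224, 224), dtype=np.float16)  # the NCHW model input
x_nhwc = np.transpose(x, NCHW_TO_NHWC)
print(x_nhwc.shape)  # (1, 224, 224, 16), matching permute_0_0 in the dump

# A toNCHW immediately followed by a toNHWC is the identity permutation.
print(compose(NHWC_TO_NCHW, NCHW_TO_NHWC))  # (0, 1, 2, 3)
```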

A potential workaround is to add a permute after each Conv2D? cc: @chenyang78

Would it be possible to make all the elementwise ops do the permutation as well? Then we would end up with a graph like:

```
toNHWC -> conv2d -> toNCHW -> toNHWC -> elementWise -> toNCHW
```

and the remove-permutations pass would then find and drop the redundant pair in the middle.
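A toy model of that idea, assuming a purely linear chain of ops: wrap conv2d and the elementwise op in layout conversions, then merge neighbouring permutes and drop any merge that collapses to the identity. This is a peephole sketch for illustration, not how AIT's own permutation pass is implemented.

```python
from typing import List, Tuple, Union

Step = Union[str, Tuple[int, ...]]  # an op name, or a permute given by its perm tuple


def compose(first, second):
    """Single permutation equal to permuting by `first`, then by `second`."""
    return tuple(first[i] for i in second)


def cancel_permutes(chain: List[Step]) -> List[Step]:
    """Merge consecutive permutes in a linear op chain and drop identity results."""
    out: List[Step] = []
    for step in chain:
        if isinstance(step, tuple) and out and isinstance(out[-1], tuple):
            merged = compose(out.pop(), step)
            if merged != tuple(range(len(merged))):
                out.append(merged)
        elif isinstance(step, tuple) and step == tuple(range(len(step))):
            continue  # a lone identity permute is a no-op
        else:
            out.append(step)
    return out


TO_NHWC = (0, 2, 3, 1)
TO_NCHW = (0, 3, 1, 2)
chain = [TO_NHWC, "conv2d", TO_NCHW, TO_NHWC, "relu", TO_NCHW]
print(cancel_permutes(chain))
# prints: [(0, 2, 3, 1), 'conv2d', 'relu', (0, 3, 1, 2)]
```

Only the conversions at the graph boundary survive, so the elementwise op runs directly on the conv2d output in NHWC.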

It sounds like that could work.

But would it be possible to try this?

  1. Permute your tensor so it's channel-last
  2. Set set_tensor_layout_policy(false) before lowering your model -- this avoids the permutes after conv2d

@jonpryai Hi, have you solved the problem?

@xmfbit No, not really. I'm just trying to get a quick estimate of a model's inference performance with AITemplate. I wonder whether starting from an FX graph instead of an ONNX model would work correctly. Otherwise, it may actually be easier to write code that builds the AITemplate model directly instead of trying to fix fx2ait.

Trying to import a typical dla34 model illustrates these issues well.