linear+gelu fused operator is not supported in ACL
snadampal opened this issue
Output of 'strings libarm_compute.so | grep arm_compute_version':
arm_compute_version=v23.11 Build options: {'Werror': '0', 'debug': '0', 'neon': '1', 'opencl': '0', 'embed_kernels': '0', 'os': 'linux', 'arch': 'armv8a', 'build': 'native', 'multi_isa': '1', 'fixed_format_kernels': '1', 'openmp': '1', 'cppthreads': '0'} Git hash=b'add70ace1e57f65d1ae4d0cedaec6e4578cf87ff'
Platform:
AWS c7g.16xl
Operating System:
Ubuntu 22.04
Problem description:
PyTorch 2.0 introduced torch.compile() for neural network compilation. One important technique that graph compilation employs is operator fusion. To execute those compiled graphs efficiently, the platform needs to support the fused operators. For example, in the BERT base model (and, I think, in any transformer model), inner_product+relu and matmul+relu (or gelu or tanh) are commonly fused in the linear layer.
The issue is that ACL 23.11 doesn't support the fused operators mentioned above, so we are not able to take full advantage of the PyTorch graph compilation optimizations on aarch64.
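For reference, the fused linear+gelu operator computes dst = GELU(src · Wᵀ + b), where GELU is in its erf form, GELU(x) = 0.5·x·(1 + erf(x/√2)), which is what oneDNN's eltwise_gelu_erf post op denotes. The minimal pure-Python sketch below contrasts the unfused and fused forms. It assumes small nested lists in place of real tensors, and the function names are illustrative only, not part of any library.

```python
import math

def gelu_erf(x: float) -> float:
    # Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def linear(src, weight, bias):
    # src: mb x ic, weight: oc x ic (nn.Linear layout), bias: oc
    return [[sum(s * w for s, w in zip(row, wrow)) + b
             for wrow, b in zip(weight, bias)]
            for row in src]

def linear_then_gelu(src, weight, bias):
    # Unfused: materialize the whole linear output, then make a second pass for GELU.
    out = linear(src, weight, bias)
    return [[gelu_erf(v) for v in row] for row in out]

def linear_gelu_fused(src, weight, bias):
    # Fused: apply GELU to each output element as soon as it is computed,
    # so the pre-activation value is never written out and read back.
    return [[gelu_erf(sum(s * w for s, w in zip(row, wrow)) + b)
             for wrow, b in zip(weight, bias)]
            for row in src]

src = [[1.0, -2.0, 0.5]]
weight = [[0.2, 0.1, -0.3], [-0.4, 0.5, 1.0]]
bias = [0.0, 0.1]
print(linear_gelu_fused(src, weight, bias))
```

Both forms produce the same values. Fusion changes memory traffic, not the math.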
Steps to reproduce:
When you run the script below, you can see that the fused operators fall back to the oneDNN reference C kernels because ACL doesn't support them.
pip3 install torch==2.1.1 transformers
export DNNL_VERBOSE=1
import torch
from transformers import BertTokenizer, BertModel
import torch._inductor.config as config
config.cpp.weight_prepack=True
config.freezing=True
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").eval()
text = "Test bert base torch.compile on aarch64 with ACL"
encoded_input = tokenizer(text, return_tensors='pt')
model.eval()
model = torch.compile(model)
with torch.set_grad_enabled(False):
    model(**encoded_input)
Note: On PyTorch main, I have disabled operator fusion for aarch64 so that the other compilation optimizations can still be used; here is the PR. So please use PyTorch 2.1.1 to reproduce the issue.
Hi @snadampal
Thanks for raising this. We will discuss the feature request with the team.
oneDNN shouldn't be falling back to the reference kernels (i.e. ref). acl_post_ops_t should try to fuse the operators in ACL and, if it can't, it should fall back to calling ACL activation functions as a separate layer. That is slower than fusion because there is an extra store and load for each element, but it is an order of magnitude faster than the reference kernels. Let me know if there are any cases like this; the fix may be relatively simple.
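A rough count shows where the extra store and load per element come from in the separate-layer fallback. This sketch counts only traffic on the dst tensor (reads of src and weights are the same in both paths and are omitted). The counter function is hypothetical and is not a model of ACL's actual kernels. For illustration it uses the problem size from the verbose log in this issue (mb28ic768oc3072, i.e. 28 × 3072 fp32 outputs).

```python
def dst_traffic(num_elements: int, fused: bool) -> dict:
    """Count dst loads/stores for GEMM + activation (src/weight reads omitted)."""
    stores = num_elements          # GEMM (or fused GEMM+activation) writes dst once
    loads = 0
    if not fused:
        loads += num_elements      # separate activation layer reads dst back...
        stores += num_elements     # ...and writes the activated result
    return {"loads": loads, "stores": stores}

n = 28 * 3072                      # mb28 x oc3072 outputs
fused = dst_traffic(n, fused=True)
separate = dst_traffic(n, fused=False)
extra_bytes = (separate["loads"] + separate["stores"]
               - fused["loads"] - fused["stores"]) * 4   # fp32 = 4 bytes
print(extra_bytes)
```

For a memory-bound activation this extra traffic dominates the cost of the activation itself. For a compute-heavy one like gelu_erf, it is a smaller fraction.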
Also, do we know the relative importance of the different activations and data types? I haven't done any in-depth analysis, but for compute-bound activations like gelu or tanh there may not be much benefit to fusing them over having a separate activation layer. For the simpler memory-bound activations, there should be a larger benefit. I think non-leaky relu (α = 0) is already fused into quite a few kernels, although as far as I know, leaky relu is not yet.
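In oneDNN, eltwise_relu is defined as d = s for s > 0 and d = α·s otherwise, so α is the slope for negative inputs. The short sketch below illustrates that definition. The function name is hypothetical and only mirrors the oneDNN semantics.

```python
def eltwise_relu(x: float, alpha: float = 0.0) -> float:
    # oneDNN-style ReLU: alpha is the slope applied to negative inputs
    return x if x > 0 else alpha * x

xs = [-2.0, -0.5, 0.0, 3.0]
plain = [eltwise_relu(x, alpha=0.0) for x in xs]     # standard (non-leaky) ReLU
leaky = [eltwise_relu(x, alpha=0.1) for x in xs]     # leaky ReLU with slope 0.1
identity = [eltwise_relu(x, alpha=1.0) for x in xs]  # slope 1 passes everything through
print(plain, leaky, identity)
```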
Hi @jondea, the torch-compiled version of bert-base (the script I provided above) has an attr-post-ops:eltwise_gelu_erf post op, which is not supported in ACL. As a result, it falls back to the C++ reference kernel for fp32 and fails to create a primitive in bf16 fast math mode (because there are no reference fast math kernels).
The post-op init returns unimplemented from here:
https://github.com/oneapi-src/oneDNN/blob/main/src/cpu/aarch64/acl_utils.cpp#L108
I'm not sure whether the gap in ACL is only the fused kernel or also the individual kernels.
For fp32:
onednn_verbose,primitive,exec,cpu,inner_product,ref:any,forward_training,src_f32::blocked:ab::f0 wei_f32::blocked:Ab8a::f0 bia_f32::blocked:a::f0 dst_f32::blocked:ab::f0,attr-scratchpad:user attr-post-ops:eltwise_gelu_erf ,,mb28ic768oc3072,17.3999
For bf16 fast math mode:
RuntimeError: could not create a primitive descriptor for an inner product forward propagation primitive
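Filtering the exec lines helps to spot these fallbacks in a long DNNL_VERBOSE=1 run. The minimal sketch below assumes the comma-separated field layout of the fp32 line above: marker, "primitive", "exec", engine, primitive kind, implementation, propagation kind, memory descriptors, attributes, auxiliary, problem descriptor, time in ms. Other oneDNN versions may lay the fields out differently. The helper names are hypothetical.

```python
from typing import Optional

def parse_exec_line(line: str) -> Optional[dict]:
    """Parse a oneDNN 'primitive,exec' verbose line; return None for anything else."""
    fields = line.strip().split(",")
    if len(fields) < 12 or fields[0] != "onednn_verbose" or fields[2] != "exec":
        return None
    attrs = fields[8].split()
    post_ops = [a.split(":", 1)[1] for a in attrs if a.startswith("attr-post-ops:")]
    return {
        "kind": fields[4],
        "impl": fields[5],
        "post_ops": post_ops,
        "problem": fields[10],
        "time_ms": float(fields[11]),
    }

def is_reference_fallback(info: dict) -> bool:
    # Reference implementations are reported with a "ref" prefix, e.g. "ref:any"
    return info["impl"].startswith("ref")

line = ("onednn_verbose,primitive,exec,cpu,inner_product,ref:any,forward_training,"
        "src_f32::blocked:ab::f0 wei_f32::blocked:Ab8a::f0 bia_f32::blocked:a::f0 "
        "dst_f32::blocked:ab::f0,attr-scratchpad:user attr-post-ops:eltwise_gelu_erf ,,"
        "mb28ic768oc3072,17.3999")
info = parse_exec_line(line)
print(info["impl"], info["post_ops"], is_reference_fallback(info))
```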
Great, thanks for the reproducer. It looks like ACL does in fact have a GELU implementation, at least for NEON FP32.
It should be straightforward to hook this up on the oneDNN side, and it will automatically get picked up by the acl_post_ops_t inside acl_inner_product_t.
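Such a hook can be pictured as a lookup from oneDNN eltwise algorithm kinds to ACL activation functions. An unmapped kind means "unimplemented", and oneDNN then tries the next implementation in its list, which is how the ref kernel ends up being used. The Python below is a hypothetical illustration, not the actual acl_utils.cpp code. The strings only mirror oneDNN's alg_kind names and ACL's ActivationFunction enumerators, and the table contents are an assumption.

```python
from typing import Optional

# Hypothetical table: oneDNN eltwise alg kind -> ACL ActivationFunction name.
# The real set of supported entries depends on the oneDNN/ACL versions in use.
ACL_ACTIVATIONS = {
    "eltwise_tanh": "TANH",
    "eltwise_logistic": "LOGISTIC",
    "eltwise_gelu_erf": "GELU",   # the case this issue asks to hook up
}

def to_acl_activation(alg: str, alpha: float = 0.0) -> Optional[str]:
    """Return the ACL activation for a oneDNN post op, or None (= unimplemented)."""
    if alg == "eltwise_relu":
        # Zero negative slope is plain ReLU; any other slope is leaky ReLU
        return "RELU" if alpha == 0.0 else "LEAKY_RELU"
    return ACL_ACTIVATIONS.get(alg)

print(to_acl_activation("eltwise_gelu_erf"), to_acl_activation("eltwise_soft_relu"))
```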
I have opened an internal issue to look into this and will get back to you. Things are quite busy at the moment, so I'll need to get back to you on timescales.
@snadampal we now have a PR up for ACL GELU erf in oneDNN: oneapi-src/oneDNN#1843. This should enable ACL primitives (including inner product) to be used when there's a GELU erf post op. This isn't fusion in the sense that the activation happens inside the GEMM kernel, but it does mean that you can make use of the ACL-accelerated kernels when there are post ops in oneDNN.
Hi @jondea, what about fusion support for the other primitive and post-op combinations? Could you please add support for matmul + post-ops like gelu/relu/erf/tanh as well?
At the oneDNN level, we should automatically support combining matmul/conv/inner product with any binary or eltwise post op supported by the equivalent standalone ACL primitive. So matmul/conv/inner product + gelu/relu/erf/tanh should be accelerated by ACL in oneDNN (GELU went into v3.5).
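oneDNN post ops form an ordered chain applied to the primitive's output, so "any binary or eltwise post op" composes in sequence after the matmul/conv/inner product. The minimal sketch below shows those semantics. It assumes eltwise entries are unary functions and binary entries combine dst elementwise with a second tensor of the same shape. The helper name and tuple encoding are hypothetical, not oneDNN's API.

```python
import math

def apply_post_ops(dst, post_ops):
    """Apply an ordered chain of ("eltwise", fn) / ("binary", fn, other) post ops."""
    out = list(dst)
    for op in post_ops:
        if op[0] == "eltwise":
            fn = op[1]
            out = [fn(v) for v in out]
        elif op[0] == "binary":
            fn, other = op[1], op[2]
            out = [fn(v, o) for v, o in zip(out, other)]
        else:
            raise ValueError(f"unknown post op kind: {op[0]}")
    return out

matmul_out = [-1.0, 0.5, 2.0]
chain = [
    ("binary", lambda a, b: a + b, [0.5, 0.5, 0.5]),  # e.g. a binary_add post op
    ("eltwise", math.tanh),                            # e.g. an eltwise_tanh post op
]
print(apply_post_ops(matmul_out, chain))
```

The order of the chain matters: adding before tanh is not the same as adding after it.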