nod-ai/iree-amd-aie

matmul-elementwise bf16 model fails to compile

yzhang93 opened this issue · 6 comments

Input IR

!lhs = tensor<1024x512xbf16>
!rhs = tensor<512x1024xbf16>
!ele = tensor<1024x1024xf32>
!res = tensor<1024x1024xbf16>

func.func @matmul_elementwise_bf16(%lhs : !lhs, %rhs : !rhs, %ele : !ele) -> !res {
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty() : !ele
  %1 = tensor.empty() : !res
  %fill = linalg.fill ins(%cst : f32) outs(%0 : !ele) -> !ele
  %2 = linalg.matmul ins(%lhs, %rhs : !lhs, !rhs) outs(%fill : !ele) -> !ele
  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2, %ele : !ele, !ele) outs(%1 : !res) {
  ^bb0(%in: f32, %in_0: f32, %out: bf16):
    %11 = arith.addf %in, %in_0 : f32
    %12 = arith.truncf %11 : f32 to bf16
    linalg.yield %12 : bf16
  } -> !res
  return %res : !res
}
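For reference, the computation this IR describes can be sketched in plain Python on small shapes: res = truncf(lhs × rhs + ele), where the final arith.truncf turns each f32 sum into a 16-bit bf16 bit pattern. This is a hypothetical host-side reference, not code from iree-amd-aie. It assumes round-to-nearest-even (the IEEE default) for arith.truncf and accumulates the matmul in Python floats (f64), so the last bits can differ from an f32 accumulator.

```python
import struct


def f32_bits(x):
    """Round a Python float to f32 and return its raw 32-bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def truncf_to_bf16_bits(x):
    """arith.truncf f32 -> bf16, assuming round-to-nearest-even; returns the 16-bit pattern."""
    bits = f32_bits(x)
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return ((bits >> 16) | 0x0040) & 0xFFFF  # NaN: force a quiet-NaN payload
    bits += 0x7FFF + ((bits >> 16) & 1)          # rounding bias; ties go to even
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_float(h):
    """Widen a bf16 bit pattern back to a float (exact)."""
    return struct.unpack("<f", struct.pack("<I", (h & 0xFFFF) << 16))[0]


def matmul_elementwise_bf16(lhs, rhs, ele):
    """Reference for the IR above on nested lists; returns bf16 bit patterns."""
    m, k, n = len(lhs), len(rhs), len(rhs[0])
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            acc = 0.0  # linalg.fill with 0.0
            for p in range(k):
                acc += lhs[i][p] * rhs[p][j]  # linalg.matmul (f64 here, not f32)
            # linalg.generic body: arith.addf followed by arith.truncf
            row.append(truncf_to_bf16_bits(acc + ele[i][j]))
        out.append(row)
    return out
```

The helper keeps results as raw 16-bit integers on purpose: after truncf the data is just 16-bit lanes, which matches the s16 element type in the error below.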

Error:

LLVM ERROR: unable to legalize instruction: %1730:_(<1024 x s16>) = G_SHUFFLE_VECTOR %1729:_(<1024 x s16>), %1475:_, shufflemask(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 
396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 
796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1024, 1025, 1026, 1027) (in function: core_0_2)
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: /proj/xsjhdstaff4/vivizhan/llvm-aie/install/bin/llc /proj/xsjhdstaff4/vivizhan/iree-amd-aie/build_tools/ci/cpu_comparison/test_result_bf16/module_matmul_elementwise_bf16_dispatch_0_amdaie_xclbin_fb/input.opt.ll -O2 --march=aie2 --function-sections --filetype=obj -o /proj/xsjhdstaff4/vivizhan/iree-amd-aie/build_tools/ci/cpu_comparison/test_result_bf16/module_matmul_elementwise_bf16_dispatch_0_amdaie_xclbin_fb/input.o
1.	Running pass 'Function Pass Manager' on module '/proj/xsjhdstaff4/vivizhan/iree-amd-aie/build_tools/ci/cpu_comparison/test_result_bf16/module_matmul_elementwise_bf16_dispatch_0_amdaie_xclbin_fb/input.opt.ll'.
2.	Running pass 'Legalizer' on function '@core_0_2'
 #0 0x000055ae9b6ceebf llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/Support/Unix/Signals.inc:567:22
 #1 0x000055ae9b6ccfc4 llvm::sys::RunSignalHandlers() /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/Support/Signals.cpp:104:20
 #2 0x000055ae9b6cd146 SignalHandler(int) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/Support/Unix/Signals.inc:412:1
 #3 0x00007fa6da842520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007fa6da8969fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x00007fa6da8969fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
 #6 0x00007fa6da8969fc pthread_kill ./nptl/pthread_kill.c:89:10
 #7 0x00007fa6da842476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007fa6da8287f3 abort ./stdlib/abort.c:81:7
 #9 0x000055ae9b6438d3 (/proj/xsjhdstaff4/vivizhan/llvm-aie/install/bin/llc+0x2cd98d3)
#10 0x000055ae9bb25532 reportGISelDiagnostic(llvm::DiagnosticSeverity, llvm::MachineFunction&, llvm::TargetPassConfig const&, llvm::MachineOptimizationRemarkEmitter&, llvm::MachineOptimizationRemarkMissed&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/GlobalISel/Utils.cpp:257:23
#11 0x000055ae9bb26f5b llvm::DiagnosticInfoOptimizationBase::~DiagnosticInfoOptimizationBase() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/IR/DiagnosticInfo.h:413:7
#12 0x000055ae9bb26f5b llvm::DiagnosticInfoMIROptimization::~DiagnosticInfoMIROptimization() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/CodeGen/MachineOptimizationRemarkEmitter.h:30:7
#13 0x000055ae9bb26f5b llvm::MachineOptimizationRemarkMissed::~MachineOptimizationRemarkMissed() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/CodeGen/MachineOptimizationRemarkEmitter.h:84:7
#14 0x000055ae9bb26f5b llvm::reportGISelFailure(llvm::MachineFunction&, llvm::TargetPassConfig const&, llvm::MachineOptimizationRemarkEmitter&, char const*, llvm::StringRef, llvm::MachineInstr const&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/GlobalISel/Utils.cpp:286:1
#15 0x000055ae9babdb82 llvm::Legalizer::runOnMachineFunction(llvm::MachineFunction&) (.part.0) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp:348:12
#16 0x000055ae9a7f9b3b llvm::MachineFunctionPass::runOnFunction(llvm::Function&) (.part.0) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/MachineFunctionPass.cpp:91:33
#17 0x000055ae9ad2eaec llvm::FPPassManager::runOnFunction(llvm::Function&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:1440:7
#18 0x000055ae9ad2ed19 llvm::ilist_node_base<true>::getNext() const /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/ADT/ilist_node_base.h:43:45
#19 0x000055ae9ad2ed19 llvm::ilist_node_impl<llvm::ilist_detail::node_options<llvm::Function, true, false, void>>::getNext() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/ADT/ilist_node.h:67:66
#20 0x000055ae9ad2ed19 llvm::ilist_iterator<llvm::ilist_detail::node_options<llvm::Function, true, false, void>, false, false>::operator++() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/ADT/ilist_iterator.h:157:25
#21 0x000055ae9ad2ed19 llvm::FPPassManager::runOnModule(llvm::Module&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:1475:22
#22 0x000055ae9ad2f59e runOnModule /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:1552:7
#23 0x000055ae9ad2f59e llvm::legacy::PassManagerImpl::run(llvm::Module&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:535:55
#24 0x000055ae99e4601e compileModule(char**, llvm::LLVMContext&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/tools/llc/llc.cpp:736:66
#25 0x000055ae99e46f86 main /proj/rdi/staff/vivizhan/llvm-aie/llvm/tools/llc/llc.cpp:420:35
#26 0x00007fa6da829d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
#27 0x00007fa6da829e40 call_init ./csu/../csu/libc-start.c:128:20
#28 0x00007fa6da829e40 __libc_start_main ./csu/../csu/libc-start.c:379:5
#29 0x000055ae99e3a2e5 _start (/proj/xsjhdstaff4/vivizhan/llvm-aie/install/bin/llc+0x14d02e5)
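To read the mask in the error above: in a shuffle, indices below the operand width (1024 lanes here) select lanes of the first operand (%1729), and indices from 1024 upward select lane (index - 1024) of the second operand (%1475). The sketch below uses a hypothetical helper, shuffle_sources, that groups a mask into contiguous runs. It assumes the mask has no undef (-1) lanes, which holds for the mask printed here.

```python
def shuffle_sources(mask, n):
    """Group a shuffle mask into runs of (operand, first_src_lane, first_dst_lane, length).

    Indices < n select from operand 0; indices >= n select lane (idx - n) of operand 1.
    Assumes no undef (-1) lanes.
    """
    runs = []
    for dst, idx in enumerate(mask):
        op, src = (0, idx) if idx < n else (1, idx - n)
        if runs:
            last_op, last_src, last_dst, length = runs[-1]
            # Extend the current run if this lane continues it contiguously.
            if last_op == op and last_src + length == src and last_dst + length == dst:
                runs[-1] = (last_op, last_src, last_dst, length + 1)
                continue
        runs.append((op, src, dst, 1))
    return runs


# The mask from the error message: 0..1019, then 1024..1027.
mask = list(range(1020)) + list(range(1024, 1028))
print(shuffle_sources(mask, 1024))  # [(0, 0, 0, 1020), (1, 0, 1020, 4)]
```

In other words, lanes 0-1019 are copied unchanged from the first operand and only lanes 1020-1023 are replaced, with lanes 0-3 of the second operand.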

In contrast, the bf16-f32 model below (without arith.truncf %11 : f32 to bf16) does not hit this error.

!lhs = tensor<1024x512xbf16>
!rhs = tensor<512x1024xbf16>
!ele = tensor<1024x1024xf32>
!res = tensor<1024x1024xf32>

func.func @matmul_elementwise_bf16(%lhs : !lhs, %rhs : !rhs, %ele : !ele) -> !res {
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty() : !ele
  %fill = linalg.fill ins(%cst : f32) outs(%0 : !ele) -> !ele
  %2 = linalg.matmul ins(%lhs, %rhs : !lhs, !rhs) outs(%fill : !ele) -> !ele
  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2, %ele : !ele, !ele) outs(%0 : !ele) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %11 = arith.addf %in, %in_0 : f32
    linalg.yield %11 : f32
  } -> !res
  return %res : !res
}

@MaheshRavishankar @stephenneuendorffer @newling @erwei-xilinx Any insight into this issue?

I don't know whether Peano handles bf16 natively.

I believe there's work going on to implement shuffle_vector. Currently, the assumption is that vector ops always go through intrinsics. FYI, for Peano issues, you're better off capturing the .ll code and creating an issue in the Peano repo.

Peano does support bf16 types, and there is indeed ongoing work to support a growing number of generic shuffle_vector cases. However, I think the real problem here is that %1730:_(<1024 x s16>) is a huge vector, and we do not yet have the capability to legalize vectors that large. As Stephen said, it would be very useful if you could send us a small .ll reproducer; then we can investigate what's really happening here :)
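One possible starting point for a small .ll reproducer is a single shufflevector with the same mask shape as the failing instruction. The generator below, shuffle_repro_ll, is hypothetical: it assumes the <1024 x s16> value came from a <1024 x bfloat> vector in the IR (it could equally have been <1024 x i16>), and whether this function alone reproduces the crash in Peano has not been checked. The .ll files generated from the real example remain the authoritative input.

```python
def shuffle_repro_ll(n=1024, keep=1020, elem="bfloat"):
    """Emit a hypothetical minimal LLVM IR function with one shufflevector.

    The mask keeps lanes 0..keep-1 of %a and fills the rest from the start of %b,
    mirroring the mask in the legalizer error (0..1019, then 1024..1027).
    """
    mask = list(range(keep)) + list(range(n, n + (n - keep)))
    vty = f"<{n} x {elem}>"
    mask_str = ", ".join(f"i32 {i}" for i in mask)
    return (
        f"define {vty} @shuffle_repro({vty} %a, {vty} %b) {{\n"
        f"  %r = shufflevector {vty} %a, {vty} %b, <{n} x i32> <{mask_str}>\n"
        f"  ret {vty} %r\n"
        f"}}\n"
    )
```

Writing the returned string to a file and running it through llc with the flags from the stack dump (-O2 --march=aie2 --filetype=obj) would show whether this shuffle on its own triggers the legalization failure.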

Support for G_SHUFFLE_VECTOR in Peano will be up for review soon, so it should land shortly. The failing instruction operates on 16-bit elements (s16), so the problem is not bf16 support in any case. There are two problems with the code as it stands:

  • The output vector is four times larger than the largest register we have available, and the code doesn't currently handle that. This is something that can be added to Peano, though; I just haven't gotten around to it yet.
  • In the generic case, G_SHUFFLE_VECTOR is incredibly slow, since it has to extract each element of the vector in turn and then rebuild the vector element by element. This instruction only changes the last four elements (64 bits), so it performs roughly 32,640 bits of useless data movement (each of the 1,020 unchanged 16-bit elements is extracted and reinserted). We can reduce this a lot by matching the patterns you depend on and replacing them with better instructions, but that requires us to know which G_SHUFFLE_VECTOR masks are needed.
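The two points above can be made concrete with a small sketch. The matcher match_insert_subvector is hypothetical (not Peano's implementation): it recognizes masks that copy operand 0 unchanged except for one contiguous run taken from operand 1, which is the shape of the failing mask. REG_BITS is derived only from the "four times larger than the largest register" remark, not from an AIE2 specification, and wasted_bits assumes the generic lowering extracts and reinserts every lane.

```python
ELEM_BITS = 16  # s16 lanes, as in the failing G_SHUFFLE_VECTOR
N = 1024        # lanes per operand
# Hypothetical register width, derived from "4 times larger than the largest register".
REG_BITS = N * ELEM_BITS // 4


def match_insert_subvector(mask, n):
    """Return (dst_offset, src_offset, length) if the mask is operand 0 with one
    contiguous run from operand 1 inserted into it; otherwise None."""
    foreign = [i for i, idx in enumerate(mask) if idx >= n]
    if not foreign:
        return None
    start, length = foreign[0], len(foreign)
    if foreign != list(range(start, start + length)):
        return None
    src = mask[start] - n
    for i, idx in enumerate(mask):
        if start <= i < start + length:
            if idx != n + src + (i - start):  # run from operand 1 must be contiguous
                return None
        elif idx != i:  # every other lane must be an identity copy of operand 0
            return None
    return (start, src, length)


def touched_chunks(mask, lanes_per_chunk):
    """Register-sized chunks containing at least one lane that is not an identity copy."""
    return sorted({i // lanes_per_chunk for i, idx in enumerate(mask) if idx != i})


def wasted_bits(mask, elem_bits=ELEM_BITS):
    """Bits moved for nothing if identity lanes are still extracted and reinserted."""
    return sum(1 for i, idx in enumerate(mask) if idx == i) * elem_bits * 2


mask = list(range(1020)) + list(range(1024, 1028))
print(match_insert_subvector(mask, N))              # (1020, 0, 4)
print(touched_chunks(mask, REG_BITS // ELEM_BITS))  # [3]
print(wasted_bits(mask))                            # 32640
```

Under these assumptions, only the last register-sized piece of the result differs from the first operand, so a pattern-matched lowering could in principle rewrite four lanes instead of rebuilding all 1024.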

Thanks @stephenneuendorffer @gbossu @ValentijnvdBeek for looking into the issue! Here are the .ll files generated from the above example. Please let me know if you need anything else from me.
input_ll.zip