ROCm/rocBLAS

[Bug]: rocblas_gemm_ex with m==1 fp16 inputs/outputs f32 compute slower than a quite naive gemv kernel on MI100

Epliz opened this issue · 18 comments

Describe the bug

As described in the title, rocblas_gemm_ex seems quite suboptimal on MI100 when m==1, the inputs/outputs are fp16, and the compute type is fp32.
A fairly naive kernel I implemented beats it.

This causes ROCm/pytorch#1408 in PyTorch.
It makes LLM inference on Mistral 7B fp16 slower than it could easily be.

To Reproduce

Here is a C++ reproducer:

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include <iostream>
#include <chrono>
#include <functional>


#define ROWS_PER_BLOCK 4
#define THREADS_PER_BLOCK 64

#define DIV_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define FULL_MASK32 0xffffffff
#define FULL_MASK64 0xffffffffffffffff

#ifdef  __CUDA_ARCH__
#define __xx_shfl_down(mask, val, offset) __shfl_down_sync(mask, val, offset)
#elif defined(__HIP_PLATFORM_AMD__) // AMD
#define __xx_shfl_down(mask, val, offset) __shfl_down(val, offset)
#else
#error "Unsupported compiler"
#endif

__device__ float warpReduce(float val) {
  if (warpSize == 32) {
    for (int offset = 16; offset > 0; offset /= 2)
      val += __xx_shfl_down(FULL_MASK32, val, offset);
  }
  if (warpSize == 64) {
    for (int offset = 32; offset > 0; offset /= 2)
      val += __xx_shfl_down(FULL_MASK64, val, offset);

  }
  return val;
}

static inline void __device__ dot2(float& acc, const float2& a, const float2& b) {
  acc += a.x * b.x;
  acc += a.y * b.y;
}

template <typename T>
static inline const T* __device__ addr(const T* p, unsigned index) {
  // helps the AMDGPU compiler understand it can use the scalar (SGPR pair) base + single VGPR offset addressing mode
  unsigned byte_offset = sizeof(T) * index;
  const uint8_t* p8 = (const uint8_t*)p;
  return (const T*) (p8 + byte_offset);
}

__global__ void muillm_gemv_kernel(
    const half* __restrict__ W, // weight matrix - size N x K
    const half* __restrict__ B, // optional bias - size N
    const half* __restrict__ X, // input = size K
    half* __restrict__ Y, // output - size N
    unsigned N,
    unsigned K
) {
  int warpCounts = THREADS_PER_BLOCK / warpSize;
  int warpId = threadIdx.x / warpSize;
  int laneId = threadIdx.x % warpSize;

  // can process ROWS_PER_BLOCK rows
  // shared state to do the reductions
  __shared__ float shared_accs[ROWS_PER_BLOCK];

  // initialize the shared memory
  if (threadIdx.x < ROWS_PER_BLOCK) {
    shared_accs[threadIdx.x] = 0.f;
  }
  if (THREADS_PER_BLOCK > warpSize) {
    __syncthreads();
  }

  {
    int current_row = blockIdx.x * ROWS_PER_BLOCK + 0;
    if (current_row + 3 < N) {

      // compute elements current_row..current_row+3 of Y by doing dot
      // products with the corresponding rows of W
      const half* W0 = &W[(current_row + 0) * K];
      const half* W1 = &W[(current_row + 1) * K];
      const half* W2 = &W[(current_row + 2) * K];
      const half* W3 = &W[(current_row + 3) * K];

      float acc0 = 0.f;
      float acc1 = 0.f;
      float acc2 = 0.f;
      float acc3 = 0.f;

      // do the dot product
      {
        unsigned k; // index along K, shared with the remainder handling below
        for (k = threadIdx.x * 2; k + 1 < K; k += (THREADS_PER_BLOCK * 2)) {
          // vectorized
          float2 x = __half22float2(*((const half2*)addr(X, k)));
          float2 w0 = __half22float2(*((const half2*)addr(W0, k)));
          float2 w1 = __half22float2(*((const half2*)addr(W1, k)));
          float2 w2 = __half22float2(*((const half2*)addr(W2, k)));
          float2 w3 = __half22float2(*((const half2*)addr(W3, k)));

          dot2(acc0, w0, x);
          dot2(acc1, w1, x);
          dot2(acc2, w2, x);
          dot2(acc3, w3, x);
        }
        if (k < K) {
          // remainder
          float x = __half2float(*addr(X,k));
          float w0 = __half2float(*addr(W0,k));
          float w1 = __half2float(*addr(W1,k));
          float w2 = __half2float(*addr(W2,k));
          float w3 = __half2float(*addr(W3,k));
          acc0 += w0 * x;
          acc1 += w1 * x;
          acc2 += w2 * x;
          acc3 += w3 * x;
        }
      }

      // warp reduce
      acc0 = warpReduce(acc0);
      acc1 = warpReduce(acc1);
      acc2 = warpReduce(acc2);
      acc3 = warpReduce(acc3);

      // reduce across warps
      if (laneId == 0) {
        atomicAdd(&shared_accs[0], acc0);
        atomicAdd(&shared_accs[1], acc1);
        atomicAdd(&shared_accs[2], acc2);
        atomicAdd(&shared_accs[3], acc3);
      }
    } else {
      for (int i = 0; i < ROWS_PER_BLOCK; i++) {
        // compute the current_row-th element of Y by doing the dot product
        // with the current_row-th row of W
        int current_row = blockIdx.x * ROWS_PER_BLOCK + i;

        if (current_row >= N)
          break;

        const half* W_ = &W[current_row * K];
      
        // do the dot product
        float acc = 0.f;
        for (int k = threadIdx.x; k < K; k += THREADS_PER_BLOCK) {
          float w = __half2float(W_[k]);
          acc += w * __half2float(X[k]);
        }

        // warp reduce
        acc = warpReduce(acc);

        // reduce across warps
        if (laneId == 0) {
          atomicAdd(&shared_accs[i], acc);
        }
      }
    }
  }

  if (THREADS_PER_BLOCK > warpSize) {
    __syncthreads();
  }

  // write out the results
  {
    if (threadIdx.x >= ROWS_PER_BLOCK)
      return;

    int current_row = blockIdx.x * ROWS_PER_BLOCK + threadIdx.x;

    if (current_row < N) {
      float acc = shared_accs[threadIdx.x]; // read the fully reduced value
      if (B != nullptr) { // add the bias first if there is one
        acc += __half2float(B[current_row]);
      }

      // write the output value
      Y[current_row] = __float2half(acc);
    }
  }
}
void muillm_linear_forward_cuda(
    const half* __restrict__ W, // size N x K
    const half* __restrict__ B, // size N
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {

  const int threads_per_blocks = THREADS_PER_BLOCK;
  const int num_blocks = DIV_ROUND_UP(N, ROWS_PER_BLOCK);

  muillm_gemv_kernel<<<num_blocks, threads_per_blocks, 0, 0>>>(
    W,
    B,
    X,
    Y,
    N,
    K
  );
}

static inline void rocblas_sgemv(rocblas_handle handle,
    const half* __restrict__ W, // size N x K
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {
  float alpha = 1.0f;
  float beta = 0.f;

  // adapted for row major from https://stackoverflow.com/questions/56043539/cublassgemm-row-major-multiplication
  rocblas_gemm_ex(handle,
                  rocblas_operation_none /*transA*/,
                  rocblas_operation_none /*transB*/,
                  1 /*m*/,
                  N /*n*/,
                  K /*k*/,
                  &alpha,
                  X /*a*/,
                  rocblas_datatype_f16_r /*a_type*/,
                  1 /*lda*/,
                  W /*b*/,
                  rocblas_datatype_f16_r /*b_type*/,
                  K /*ldb*/,
                  &beta,
                  nullptr /*c*/,
                  rocblas_datatype_f16_r /*c_type*/,
                  1 /*ldc*/,
                  Y /*d*/,
                  rocblas_datatype_f16_r /*d_type*/,
                  1 /*ldd*/,
                  rocblas_datatype_f32_r /*compute_type*/,
                  rocblas_gemm_algo_standard /*algo*/,
                  0 /*solution_index*/,
                  0 /*flags*/);
}

size_t timeus_func(size_t count, std::function<void(int)> f) {
  std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
  f(count);
  hipDeviceSynchronize();

  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();

  return std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / count;
}

int main(int argc, char** argv) {
  int in_features=4096, out_features=14336;
  int tot_features = in_features * out_features;

  // allocate matrices and vectors
  half* x_small = nullptr;
  half* x_big = nullptr;
  half* w_up = nullptr;
  half* w_down = nullptr;

  std::cout<<"Allocating memory..."<<std::endl;
  if (hipMalloc(&x_small, sizeof(half) * in_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&x_big, sizeof(half) * out_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&w_up, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&w_down, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }

  // set memory
  std::cout<<"Setting memory..."<<std::endl;
  if (hipMemsetD16(x_small, 0, in_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(x_big, 0, out_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_up, 0, tot_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_down, 0, tot_features) != hipSuccess) {
    return -1;
  }

  //
  std::cout<<"Running..."<<std::endl;

  int count = 10000;

  {

    auto mui_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        muillm_linear_forward_cuda(w_up, nullptr, x_small, x_big, out_features, in_features);
        muillm_linear_forward_cuda(w_down, nullptr, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      mui_prof
    );

    // measurement
    size_t mui_time = timeus_func(
      count,
      mui_prof
    );

    std::cout<<"mui: "<<mui_time<<"us/loop"<<std::endl;
  }

  {// rocblas
    rocblas_initialize();
    rocblas_handle handle;
    if(rocblas_create_handle(&handle) != rocblas_status_success) return -3;

    auto rocblas_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        rocblas_sgemv(handle, w_up, x_small, x_big, out_features, in_features);
        rocblas_sgemv(handle, w_down, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      rocblas_prof
    );

    // measurement
    size_t rocblas_time = timeus_func(
      count,
      rocblas_prof
    );

    std::cout<<"rocblas: "<<rocblas_time<<"us/loop"<<std::endl;
  }

  std::cout<<"DONE"<<std::endl;
  return 0;
}
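
For reference, a way to build the reproducer, assuming it is saved as repro.cpp (the MI100 offload arch and the library flag are assumptions; adjust to your ROCm install):

hipcc -O3 --offload-arch=gfx908 repro.cpp -lrocblas -o repro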

Expected behavior

It should be at least as fast as my naive kernel.
But running the above, I get:

Allocating memory...
Setting memory...
Running...
mui: 227us/loop
rocblas: 386us/loop
hipblas: 386us/loop
DONE

Environment

Hardware description
CPU AMD Ryzen 7 5800X3D 8-Core Processor
GPU AMD Instinct MI100
Software version
rocm-core v6.0.2.60002-115~22.04
rocblas v4.0.0.60002-115~22.04

environment.txt

Additional context

Add any other context about the problem here.

EDIT: replaced the originally included kernel with a better one
EDIT2: replaced it with an even better kernel

Yeah, this has been an issue for a while: #1238

I updated the kernel in my reproducer; it now saturates memory bandwidth (unlike rocBLAS).

I see that @daineAMD replied to the other issue, so I am mentioning it here as well, in case that helps in any way.
To give some context again if needed: improving rocblas_gemm_ex for the cases where it corresponds to a gemv matters because this is a very common pattern for LLM inference at batch size 1, which gets benchmarked quite often.
Given that a ~100-line kernel beats rocBLAS by 2x, I would recommend putting some effort into this. At the very least, you could make sure it gets decent performance for the matrix shapes of popular LLMs.

It's also pretty silly, since just using the gemv kernels in these cases should be trivial.

The suboptimality of this is of course also easily shown with rocBLAS's own tool:

rocblas-bench -f gemm_ex -m 1 -n 16192 -k 16192 --transposeA N --transposeB N -r s --compute_type s -i 50

transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,ldd,batch_count,rocblas-Gflops,us
N,N,1,16192,16192,1,128,0,16192,128,128,1, 203.735, 2573.74

rocblas-bench -f gemv -r s -m 16192 -n 16192 --lda 16192 -i 50

transA,M,N,alpha,lda,incx,beta,incy,rocblas-Gflops,rocblas-GB/s,us
N,16192,16192,1,16192,1,0,1, 480.082, 960.224, 1092.3

A rocblas_hgemv would also be great, since there is an opportunity here to use dual-issue.

Hi @Epliz, thanks for bringing this up. Yes, the disparity between gemm with m == 1 / n == 1 and gemv has been brought up in the past, as noted by @IMbackK. Back when it was originally raised, it wasn't clear whether the best approach would be to redirect the gemm call to gemv (which has source kernels in rocBLAS) or to continue to gemm (which is handled within the Tensile library), since performance was somewhat of a mixed bag, and handling this on a case-by-case basis seemed infeasible.

Regardless, it's good that this has been brought up again, and I'll discuss with the team what the best approach is. If we can get gemv to outperform gemm in every case, then the changes to redirect to gemv would be straightforward, but most of the work would lie in ensuring that gemv is faster. I'll keep you updated with any progress here.

The request for rocblas_hgemv() has also been noted and I can discuss with the team about whether or not we plan on supporting this.

Thanks,
Daine

Hi @daineAMD

Thank you for the detailed comment on this matter and for:
The request for rocblas_hgemv() has also been noted and I can discuss with the team about whether or not we plan on supporting this.

Out of curiosity:
In initial experimentation with rocblas-bench, I have been unable to find a configuration where gemm_ex beats gemv on gfx906, gfx908, or gfx1030. If you have some notes on which configurations these could be, that would be interesting to me from a performance-optimization perspective for my own code.

Hi @daineAMD ,

Following up after a week.
Do you have any example of a configuration where gemv is slower than gemm?

If not, can you please proceed with making gemm call gemv for those cases?

If the rocBLAS team cannot tackle this task, would a pull request from my side potentially be merged? I can sign whatever contribution agreement you might need.

Hi @Epliz and @IMbackK, sorry for the delay.

Looking at my past notes, the areas of most concern were where the incx parameter is large (with various exceptions), specifically gemm cases where (transA == transB == T && ldb >> 1) and (transA == transB == N && lda >> 1).
For example, the following gemm and gemv calls are essentially the same operation:
./rocblas-bench -f gemm -r f32_r --transposeA N --transposeB N -m 1 -n 2048 -k 2048 --alpha 1 --lda 2048 --beta 0 --ldb 2048 --ldc 1
and
./rocblas-bench -f gemv -r f32_r --transposeA T -m 2048 -n 2048 --lda 2048 --incx 2048. Note the large incx here, which corresponds to the lda in the gemm call. You can try this out yourself, but I'm getting better performance with gemm than gemv here on MI100.

Other cases where I'm seeing gemm perform better than gemv are small sizes, e.g.:
./rocblas-bench -f gemm -r f32_r --transposeA N --transposeB N -m 1 -n 1024 -k 1024 --alpha 1 --lda 1 --beta 0 --ldb 1024 --ldc 1
and
./rocblas-bench -f gemv -r f32_r --transposeA T -m 1024 -n 1024 --lda 1024 --incx 1
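
For readers following along, the first gemm/gemv pair above amounts to the following two calls performing the same operation (a minimal sketch; handle, d_A, d_B, d_C, alpha and beta are placeholders, with n = k = 2048):

// gemm view: C (1 x n) = A (1 x k) * B (k x n), column-major
rocblas_sgemm(handle, rocblas_operation_none, rocblas_operation_none,
              /*m=*/1, /*n=*/n, /*k=*/k,
              &alpha, d_A, /*lda=*/k, d_B, /*ldb=*/k, &beta, d_C, /*ldc=*/1);

// gemv view of the same operation: y = B^T x, where x is the row of A read
// with stride lda and y is the row of C written with stride ldc
rocblas_sgemv(handle, rocblas_operation_transpose,
              /*m=*/k, /*n=*/n,
              &alpha, d_B, /*lda=*/k, d_A, /*incx=*/k, &beta, d_C, /*incy=*/1);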

I have a ticket to investigate further to see if we can call gemv for the cases where it outperforms gemm, and/or to see what optimizations can be done for the current gemv to make this easier; I'll be looking at this in the coming weeks.

You are free to take a look yourself and open a PR; you can take a look at the contributing guide if you're interested. Merging the PR will still take some time, though, as most of the work lies in ensuring there are no performance regressions.

Thanks again,
Daine

Hi @daineAMD,

Thank you for your examples; they have been useful in determining when to use gemv in my code to work around this issue and when not to.
Since this issue has now been quiet for a month, and the previous issue on this topic was never resolved after two years, I think it prudent to follow up and ask whether any internal progress has been made or a decision reached on the way forward with this performance problem.

Hi @IMbackK,

Yes, it's good to keep this topic up to date since it's been delayed for so long; thanks for your reminder. No decision has been made on a way forward yet. Currently, we are working on some potential optimizations for the gemv function, so I thought it best to hold off on making any changes until I can evaluate the performance of those gemv changes, in case that makes the decision easier.

In the meantime, I've mocked up some changes to potentially allow users to opt in to using gemv kernels from rocblas_gemm_ex() calls with m == 1 || n == 1 (and other restrictions). We'll be discussing this option once the gemv changes mentioned above are in.

Also, regarding half-precision gemv support, the following functions are in rocBLAS as of ROCm 6.0:

  • rocblas_hshgemv_batched() / rocblas_hshgemv_strided_batched()
  • rocblas_hssgemv_batched() / rocblas_hssgemv_strided_batched()
  • rocblas_tstgemv_batched() / rocblas_tstgemv_strided_batched()
  • rocblas_tssgemv_batched() / rocblas_tssgemv_strided_batched()

You can see their definitions in rocblas_functions.h. The precision prefixes represent the input-compute-output types (e.g. hss is half-precision input, single-precision compute and output). It looks like they weren't added to the docs until ROCm 6.2, so they should appear in the rocBLAS documentation with ROCm 6.2. Sorry for not mentioning them previously; they slipped my mind.
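
As a rough illustration, the m == 1 fp16-in/fp16-out, fp32-compute case from the reproducer could map onto one of these as follows. This is only a sketch: the argument order is assumed to mirror the standard gemv_strided_batched layout, so please check rocblas_functions.h for the exact signature. W is the row-major N x K weight matrix, X the K-length input, Y the N-length output:

float alpha = 1.0f, beta = 0.0f;
// Y = W * X with fp16 inputs/outputs and fp32 accumulation; the row-major
// N x K matrix W is viewed as a column-major K x N matrix and applied transposed.
rocblas_hshgemv_strided_batched(handle,
                                rocblas_operation_transpose,
                                /*m=*/K, /*n=*/N,
                                &alpha,
                                (const rocblas_half*)W, /*lda=*/K, /*stride_A=*/0,
                                (const rocblas_half*)X, /*incx=*/1, /*stride_x=*/0,
                                &beta,
                                (rocblas_half*)Y, /*incy=*/1, /*stride_y=*/0,
                                /*batch_count=*/1);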

Thanks,
Daine

Thanks @daineAMD for the reply.
I still believe that, even if not always dispatching those cases to the gemv kernel, dispatching the configurations known to be faster with gemv would already be great.
If that would be helpful to you, I would be happy to provide some shapes used in open-weight LLMs where inference at batch size 1 would benefit from gemm-to-gemv lowering.

For example, for the Mistral 7B model, the matrix shapes are listed below (an example bench invocation follows the list):

  • 4096 x 1024 (k/v matrices)
  • 4096 x 4096 (q matrix)
  • 4096 x 14336 (gate/up proj matrices)
  • 14336 x 4096 (down proj matrix)
  • 4096 x 6144 (if fusing q k v matrices)
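
As an illustration, benchmarking the 4096 x 14336 gate/up projection at batch size 1 could look like the gemm_ex invocation shown earlier in this thread, with this shape swapped in (single precision shown, as in that earlier example):

rocblas-bench -f gemm_ex -m 1 -n 14336 -k 4096 --transposeA N --transposeB N -r s --compute_type s -i 50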

Hi @daineAMD,

Thank you for your quick update.

In the meantime, I've mocked up some changes to potentially allow users to opt in to using gemv kernels from rocblas_gemm_ex() calls with m == 1 || n == 1 (and other restrictions). We'll be discussing this option once the gemv changes mentioned above are in.

Having this selectable via rocBLAS's API or an environment variable would work great for me as an interim solution, and presumably also for @Epliz.

Also, regarding half-precision gemv support, the following functions are in rocBLAS as of ROCm 6.0:

* `rocblas_hshgemv_batched()` / `rocblas_hshgemv_strided_batched()`

* `rocblas_hssgemv_batched()` / `rocblas_hssgemv_strided_batched()`

* `rocblas_tstgemv_batched()` / `rocblas_tstgemv_strided_batched()`

* `rocblas_tssgemv_batched()` / `rocblas_tssgemv_strided_batched()`

Indeed, I was not aware of these functions due to the lack of documentation; thank you for bringing them to my attention! Thus far I have been upcasting to fp32.

Thanks for your feedback on the documentation @IMbackK. A missing space had obscured the changelog bullet for these functions, which I have just fixed and clarified in commit f087847.