[5.3.X][CUDA >= 11.0] `hipblasGemmEx` doesn't fully match `cublasGemmEx`
emankov opened this issue · 2 comments
emankov commented
The problem is with the penultimate argument, `hipblasDatatype_t computeType`, which doesn't match `cublasComputeType_t computeType`. `cublasComputeType_t` appeared in CUDA 11.0; from CUDA 8.0 until CUDA 11.0, `cublasGemmEx` used `cudaDataType` instead of `cublasComputeType_t` for its penultimate argument.
// hipBLAS declaration of hipblasGemmEx, quoted for comparison with cublasGemmEx.
// NOTE(review): every type argument here, including the penultimate
// `computeType`, is a hipblasDatatype_t. Since CUDA 11.0, cublasGemmEx takes a
// cublasComputeType_t for that position, which is a different enum with
// different enumerators and values. So the two signatures do not match 1:1.
HIPBLAS_EXPORT hipblasStatus_t hipblasGemmEx(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const void* alpha,
const void* A,
hipblasDatatype_t aType,
int lda,
const void* B,
hipblasDatatype_t bType,
int ldb,
const void* beta,
void* C,
hipblasDatatype_t cType,
int ldc,
hipblasDatatype_t computeType, // mismatch: cublasGemmEx (CUDA >= 11.0) uses cublasComputeType_t here
hipblasGemmAlgo_t algo);
// cuBLAS declaration of cublasGemmEx in its CUDA >= 11.0 form.
// The matrix element types (Atype/Btype/Ctype) are cudaDataType, but the
// penultimate `computeType` is cublasComputeType_t. From CUDA 8.0 until
// CUDA 11.0 this parameter was a cudaDataType instead.
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmEx(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void* alpha, /* host or device pointer */
const void* A,
cudaDataType Atype,
int lda,
const void* B,
cudaDataType Btype,
int ldb,
const void* beta, /* host or device pointer */
void* C,
cudaDataType Ctype,
int ldc,
cublasComputeType_t computeType, /* CUDA >= 11.0; was cudaDataType in CUDA 8.0 - 11.0 */
cublasGemmAlgo_t algo);
// hipBLAS data-type enum. In the quoted hipblasGemmEx it serves both as the
// matrix element type (aType/bType/cType) and as the computeType.
// Its values (150-169) do not overlap the cublasComputeType_t values (64-77).
// As quoted, it has no entries for compute modes such as "pedantic" or
// "fast" down-conversion.
typedef enum
{
HIPBLAS_R_16F = 150, /**< 16 bit floating point, real */
HIPBLAS_R_32F = 151, /**< 32 bit floating point, real */
HIPBLAS_R_64F = 152, /**< 64 bit floating point, real */
HIPBLAS_C_16F = 153, /**< 16 bit floating point, complex */
HIPBLAS_C_32F = 154, /**< 32 bit floating point, complex */
HIPBLAS_C_64F = 155, /**< 64 bit floating point, complex */
HIPBLAS_R_8I = 160, /**< 8 bit signed integer, real */
HIPBLAS_R_8U = 161, /**< 8 bit unsigned integer, real */
HIPBLAS_R_32I = 162, /**< 32 bit signed integer, real */
HIPBLAS_R_32U = 163, /**< 32 bit unsigned integer, real */
HIPBLAS_C_8I = 164, /**< 8 bit signed integer, complex */
HIPBLAS_C_8U = 165, /**< 8 bit unsigned integer, complex */
HIPBLAS_C_32I = 166, /**< 32 bit signed integer, complex */
HIPBLAS_C_32U = 167, /**< 32 bit unsigned integer, complex */
HIPBLAS_R_16B = 168, /**< 16 bit bfloat, real */
HIPBLAS_C_16B = 169, /**< 16 bit bfloat, complex */
} hipblasDatatype_t;
// cuBLAS compute-type enum, introduced in CUDA 11.0. Since then it is the type
// of cublasGemmEx's computeType parameter.
// Besides the plain precisions (16F/32F/64F/32I), it encodes compute modes:
// *_PEDANTIC and the 32F *_FAST_* variants, which allow down-converting inputs.
// These modes have no counterpart in hipblasDatatype_t as quoted.
// Note that the values are not in declaration order (e.g. _FAST_TF32 = 77
// precedes 64F = 70).
typedef enum {
CUBLAS_COMPUTE_16F = 64, /* half - default */
CUBLAS_COMPUTE_16F_PEDANTIC = 65, /* half - pedantic */
CUBLAS_COMPUTE_32F = 68, /* float - default */
CUBLAS_COMPUTE_32F_PEDANTIC = 69, /* float - pedantic */
CUBLAS_COMPUTE_32F_FAST_16F = 74, /* float - fast, allows down-converting inputs to half or TF32 */
CUBLAS_COMPUTE_32F_FAST_16BF = 75, /* float - fast, allows down-converting inputs to bfloat16 or TF32 */
CUBLAS_COMPUTE_32F_FAST_TF32 = 77, /* float - fast, allows down-converting inputs to TF32 */
CUBLAS_COMPUTE_64F = 70, /* double - default */
CUBLAS_COMPUTE_64F_PEDANTIC = 71, /* double - pedantic */
CUBLAS_COMPUTE_32I = 72, /* signed 32-bit int - default */
CUBLAS_COMPUTE_32I_PEDANTIC = 73, /* signed 32-bit int - pedantic */
} cublasComputeType_t;
emankov commented
The same applies to:
- `cublasGemmBatchedEx` -> `hipblasGemmBatchedEx`
- `cublasGemmStridedBatchedEx` -> `hipblasGemmStridedBatchedEx`