cuda实现的self attention算子
cuda版本:10.2
显卡:NVIDIA GeForce RTX 2060
pytorch版本:1.9.0
cuda中实现了多个核函数,具体如下:
矩阵乘法核函数,使用CSE599W第五节所讲的GEMM算法
矩阵转置核函数
求行的最大值核函数,使用规约的方法实现(类似于倍增的**)
广播进行减操作和幂操作的核函数
求行的和核函数,使用规约的方法实现(类似于倍增的**)
广播进行除操作的核函数
灵活组合使用上面的核函数即可实现self attention
实际测试中,在输入维度为(500,1500),输出维度为(500,800)的情况下,和使用pytorch实现的self attention比较,使用运行十次取平均值的方法,执行时间对比如下:
(过大的数据显卡内存跑不下)
可以看到在数据较小的情况下,算法的性能是优于torch实现的算子的
算子的自动求导
batch的boradcast核函数的实现
感觉实现的过程有点像CSE599W作业2写TVM的流程,但是写cuda要比tvm要复杂一些,涉及一些并行的算法
实现了多头注意力机制及其反向传播
比torch实现的最navie的multi_head_attention要快很多,还未跟lightseq进行比较
详情可见文件multi_head_attention分析.md
结合cuBLAS、cu等库,进一步提升速度
当前版本cuda backward存在在大数据情况下与pytorch backward结果相差过多的问题,还在debuging