Some generated CUDA kernels have an input whose shape is empty (`[]`)
VincentXWD opened this issue · 0 comments
VincentXWD commented
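Hello, I noticed that some generated CUDA kernels have an input whose shape is empty (`"shape": []` in the kernel's meta.json). Here is the Python code that defines the model and compiles it with the hidet backend: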
I'd like to know why this happens. Is it a bug? Thanks.
import torch
import torch._dynamo
from torch import nn
import hidet
import math
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style.

    Computes ``weight * (x - mean) / sqrt(var + eps) + bias``. ``mean`` and the
    biased ``var`` are taken over the last dimension (``keepdim=True``), and
    epsilon is added *inside* the square root.

    Args:
        hidden_size: size of the last dimension. It sets the shape of the
            learnable ``weight`` (initialized to ones) and ``bias``
            (initialized to zeros).
        eps: constant added to the variance before the square root.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        # Registration order (weight, then bias) is kept so state_dict keys
        # and ordering are unchanged.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        # NOTE: under the hidet backend this subtract+pow appears to be fused
        # into a kernel named `fused_subtract_pow`. That kernel receives the
        # scalar exponent 2 as a 0-d (shape []) tensor input.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class SelfAttention(nn.Module):
    """BERT-style multi-head self-attention with a residual LayerNorm output.

    ``forward`` does the following:
    1. Projects the input to queries, keys and values.
    2. Computes scaled dot-product attention per head.
    3. Merges the heads.
    4. Applies an output projection and dropout.
    5. Returns ``LayerNorm(hidden + input)``.

    Args:
        num_attention_heads: number of attention heads.
        input_size: size of the input's last dimension.
        hidden_size: combined size of all heads; must be a multiple of
            ``num_attention_heads``. The residual add in ``forward`` assumes
            ``input_size == hidden_size``.
        attention_probs_dropout_prob: dropout probability applied to the
            attention probabilities.
        hidden_dropout_prob: dropout probability applied to the projected
            output.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, num_attention_heads, input_size, hidden_size,
                 attention_probs_dropout_prob, hidden_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        # Exact integer division (divisibility is checked above). Unlike
        # int(a / b), it never round-trips through a float.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = hidden_size

        # Submodule registration order is unchanged, so state_dict keys match.
        self.query = nn.Linear(input_size, self.all_head_size)
        self.key = nn.Linear(input_size, self.all_head_size)
        self.value = nn.Linear(input_size, self.all_head_size)
        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, S, H*D) -> (B, H, S, D).

        Expects a 3-D input; the 4-D ``permute`` below requires it.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_tensor, attention_mask=None):
        """Run self-attention and return ``LayerNorm(dropout(dense(ctx)) + input)``.

        Args:
            input_tensor: (batch_size, seq_len, input_size) tensor.
            attention_mask: optional additive mask that is broadcast onto the
                (batch_size, heads, seq_len, seq_len) scores, e.g. a
                (batch_size, 1, 1, seq_len) tensor. The default ``None``
                applies no mask, which is identical to the previous behavior.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        query_layer = self.transpose_for_scores(self.query(input_tensor))
        key_layer = self.transpose_for_scores(self.key(input_tensor))
        value_layer = self.transpose_for_scores(self.value(input_tensor))

        # Raw attention scores: dot product of queries and keys, scaled by
        # 1/sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Equivalent to nn.Softmax(dim=-1)(scores), without constructing a new
        # module on every forward call.
        attention_probs = torch.softmax(attention_scores, dim=-1)
        # This drops out entire tokens to attend to, which may seem unusual,
        # but follows the original Transformer paper.
        # FIXME: carried over from the original author; reason not recorded.
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge heads: (B, H, S, D) -> (B, S, H, D) -> (B, S, H*D).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        hidden_states = self.dense(context_layer)
        hidden_states = self.out_dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)
# Repro driver: build SelfAttention on the GPU in eval mode, compile it with
# the hidet backend and run a single forward pass.
# The directory passed to cache_dir is presumably where hidet stores its
# generated kernels and their meta.json files. TODO confirm against hidet docs.
hidet.option.cache_dir('./outs/cache')
attention_config = dict(
    num_attention_heads=12,
    input_size=768,
    hidden_size=768,
    attention_probs_dropout_prob=0.5,
    hidden_dropout_prob=0.5,
)
model = SelfAttention(**attention_config).cuda().eval()
x = torch.rand(1, 128, 768).cuda()
model_opt = torch.compile(model, backend='hidet')
y = model_opt(x)
Here is the meta.json of the 12th kernel:
{
"name": "fused_subtract_pow",
"symbols": [],
"inputs": [
{
"device": "cuda",
"dtype": "float32",
"shape": []
},
{
"device": "cuda",
"dtype": "float32",
"shape": [
1,
128,
768
]
},
{
"device": "cuda",
"dtype": "float32",
"shape": [
1,
128,
1
]
}
],
"outputs": [
{
"device": "cuda",
"dtype": "float32",
"shape": [
1,
128,
768
]
}
],
"target": "cuda",
"num_candidates": 1,
"hidet_version": "0.3.1.dev"
}
Here is the generated kernel:
#include <stdint.h>
#include <hidet/runtime/symbols.h>
#include <hidet/runtime/memory_planner.h>
#include <hidet/runtime/cpu/context.h>
#include <hidet/runtime/cuda/complex.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>
// Fused element-wise kernel: z[i] = powf(x[i] - y[i / 768], y_1[0]).
// - x and z are addressed per element.
// - y holds one value per group of 768 consecutive elements, broadcast along
//   the last dimension.
// - y_1 is only ever read at element 0, so it acts as a scalar exponent.
//   hidet_launch_0 binds it to its first argument, which is input 0, and
//   hidet_get_input_shape reports no dims for input 0 (a 0-d tensor).
// There is no bounds check: hidet_launch_0 launches 192 x 512 = 98304 threads,
// exactly 1 * 128 * 768 output elements.
static __global__ void __launch_bounds__(512) hidet_fused_compute_z(float * __restrict__ x, float * __restrict__ y, float * __restrict__ y_1, float * __restrict__ z) {
  const int flat = ((int)blockIdx.x * 512) + (int)threadIdx.x;
  const int row = flat / 768;
  const int col = flat % 768;
  const int elem = (row * 768) + col;  // equals `flat`; kept in row/col form
  z[elem] = powf(x[elem] - y[row], y_1[0]);
}
// Writes the static shape of kernel input `idx` into `dims`.
// - Input 0 has rank 0 (meta.json lists "shape": []), so nothing is written.
// - Inputs 1 and 2 are rank-3.
// - Any other index leaves `dims` untouched.
DLL void hidet_get_input_shape(int32_t idx, int32_t * __restrict__ dims) {
  switch (idx) {
    case 0:
      // 0-d (scalar) input: no dimensions to report.
      break;
    case 1:
      dims[0] = 1;
      dims[1] = 128;
      dims[2] = 768;
      break;
    case 2:
      dims[0] = 1;
      dims[1] = 128;
      dims[2] = 1;
      break;
    default:
      break;
  }
}
// Writes the static shape of kernel output `idx` into `dims`.
// There is a single rank-3 output (index 0); other indices leave `dims`
// untouched.
DLL void hidet_get_output_shape(int32_t idx, int32_t * __restrict__ dims) {
  if (idx != 0) {
    return;
  }
  dims[0] = 1;
  dims[1] = 128;
  dims[2] = 768;
}
// Host entry point. Launches hidet_fused_compute_z on the current hidet CUDA
// stream and logs any launch error.
// The parameters follow input order: y = input 0 (0-d), x = input 1,
// y_1 = input 2, z = output 0. The kernel's own parameters are permuted
// relative to that order:
// - kernel x   <- x   (input 1)
// - kernel y   <- y_1 (input 2)
// - kernel y_1 <- y   (input 0, the scalar exponent)
DLL void hidet_launch_0(float * __restrict__ y, float * __restrict__ x, float * __restrict__ y_1, float * __restrict__ z) {
  const dim3 grid_dim(192, 1, 1);
  const dim3 block_dim(512, 1, 1);
  cudaStream_t stream = (cudaStream_t)get_cuda_stream();
  hidet_fused_compute_z<<<grid_dim, block_dim, 0, stream>>>(x, y_1, y, z);
  cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    LOG(ERROR) << "CUDA error: " << cudaGetErrorString(launch_status) << "\n";
  }
}