Approximated values are off
jglaser opened this issue · 2 comments
jglaser commented
I wrote a simple test that compares the output of the hierarchical-transformer self-attention (`HAttention1D`) with the BERT self-attention from Hugging Face Transformers:
import torch
import torch.nn as nn
import math
from h_transformer_1d.h_transformer_1d import HAttention1D
def transpose_for_scores(x, num_attention_heads, attention_head_size):
    """Split the last axis of ``x`` into heads and move the head axis forward.

    Args:
        x: rank-3 tensor whose last axis has size
            ``num_attention_heads * attention_head_size``. The 4-axis permute
            below only works for rank-3 input.
        num_attention_heads: number of heads to split the last axis into.
        attention_head_size: width of each head.

    Returns:
        A (non-contiguous) view with axes 1 and 2 swapped after the split, i.e.
        ``(b, s, h*d) -> (b, h, s, d)``.
    """
    # ``view`` (not ``reshape``): shares storage with ``x``.
    per_head = x.view(*x.shape[:-1], num_attention_heads, attention_head_size)
    # (b, s, h, d) -> (b, h, s, d)
    return per_head.permute(0, 2, 1, 3)
def bert_self_attention(query, key, value_layer, attention_mask=None, num_attention_heads=1):
    """Exact multi-head scaled dot-product attention in the style of BERT.

    Unlike Hugging Face's ``BertSelfAttention``, no learned projections and no
    dropout are applied: ``query``, ``key`` and ``value_layer`` are used as-is.
    That makes this a deterministic reference to compare approximations against.

    Args:
        query: tensor of shape (batch, query_len, hidden). ``hidden`` must be
            divisible by ``num_attention_heads`` for the head split to succeed.
        key: tensor of shape (batch, key_len, hidden).
        value_layer: tensor of shape (batch, key_len, hidden). The parameter
            name is kept for backward compatibility with existing callers.
        attention_mask: optional additive mask, broadcastable to
            (batch, heads, query_len, key_len). It is added to the scaled scores
            before the softmax, so masked positions should carry a large
            negative value.
        num_attention_heads: number of heads to split ``hidden`` into.

    Returns:
        ``(context_layer, attention_probs)``:
        - ``context_layer`` has shape (batch, query_len, hidden).
        - ``attention_probs`` has shape (batch, heads, query_len, key_len).
    """
    dim_head = query.size()[-1] // num_attention_heads
    all_head_size = dim_head * num_attention_heads

    query_layer = transpose_for_scores(query, num_attention_heads, dim_head)
    key_layer = transpose_for_scores(key, num_attention_heads, dim_head)
    # Use the ``value_layer`` argument. Reading a module-level ``value`` would
    # silently depend on whatever global exists, or raise NameError when none does.
    value_layer = transpose_for_scores(value_layer, num_attention_heads, dim_head)

    # Raw scores: (b, h, q, d) @ (b, h, d, k) -> (b, h, q, k), scaled by 1/sqrt(d).
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(dim_head)
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    # Normalize over the key axis. Dropout is intentionally not applied here.
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)

    # (b, h, q, k) @ (b, h, k, d) -> (b, h, q, d) -> (b, q, h*d)
    context_layer = torch.matmul(attention_probs, value_layer)
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    context_layer = context_layer.view(*context_layer.size()[:-2], all_head_size)
    return context_layer, attention_probs
if __name__ == "__main__":
    # Tiny input of shape (1, 4, 2), used as query, key and value
    # (i.e. self-attention).
    query = torch.tensor([[[0.1, 0.2], [-0.5, 0.7], [-0.5, -0.75], [.123, .456]]])
    # Shorter alternative input:
    # query = torch.tensor([[[0.1,0.2],[-0.5,0.7]]])
    key = value = query
    n_heads = 1

    # Exact softmax attention serves as the reference output.
    attn, _ = bert_self_attention(query, key, value, num_attention_heads=n_heads)
    print('bert_self_attention out: ', attn)

    model_dim = query.size()[-1]
    head_dim = model_dim // n_heads
    for block_size in (1, 2):
        h_attn = HAttention1D(
            dim=model_dim,
            heads=n_heads,
            dim_head=head_dim,
            block_size=block_size,
        )
        # Replace to_qkv / to_out with identities so the output is directly
        # comparable with the projection-free reference above.
        # NOTE(review): assumes these are HAttention1D's input/output projections.
        h_attn.to_qkv = torch.nn.Identity()
        h_attn.to_out = torch.nn.Identity()
        # Pack q, k, v per token into one last axis of width 3 * model_dim.
        # NOTE(review): assumes HAttention1D chunks the last axis into q/k/v.
        packed_qkv = torch.stack([query, key, value], dim=2).flatten(start_dim=2)
        h_out = h_attn(packed_qkv)
        print('hattention_1d: (block_size = {})'.format(block_size), h_out)
This is the output I get:
bert_self_attention: tensor([[[-0.1807, 0.1959],
[-0.2096, 0.2772],
[-0.2656, -0.0568],
[-0.1725, 0.2442]]])
hattention_1d: (block_size = 1) tensor([[[-0.2000, 0.4500],
[-0.2000, 0.4500],
[-0.1885, -0.1470],
[-0.1885, -0.1470]]])
before it errors out on the second iteration (block_size = 2) with
assert num_levels >= 0, 'number of levels must be at least greater than 0'
Some of the values are off by more than a factor of two in absolute magnitude (e.g. 0.4500 vs. 0.1959 in the first row).
Looking at the code, the line in `HAttention1D` that computes `num_levels` seems problematic (the line itself is not reproduced here).
I believe it should read:
num_levels = int(log2(pad_to_len // bsz)) - 1
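With that change, the script also runs for block_size = 2. The block_size = 2 approximation is much closer to the exact output, while with block_size = 1 every row collapses to the same value: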
bert_self_attention out: tensor([[[-0.1807, 0.1959],
[-0.2096, 0.2772],
[-0.2656, -0.0568],
[-0.1725, 0.2442]]])
hattention_1d: (block_size = 1) tensor([[[-0.2590, 0.2020],
[-0.2590, 0.2020],
[-0.2590, 0.2020],
[-0.2590, 0.2020]]])
hattention_1d: (block_size = 2) tensor([[[-0.1808, 0.1972],
[-0.1980, 0.2314],
[-0.2438, 0.0910],
[-0.1719, 0.2413]]])