lucidrains/h-transformer-1d

Approximated values are off

jglaser opened this issue · 2 comments

I wrote a simple test to check the output of the hierarchical transformer self-attention against the BERT self-attention from Hugging Face Transformers.

import torch
import torch.nn as nn
import math

from h_transformer_1d.h_transformer_1d import HAttention1D

def transpose_for_scores(x, num_attention_heads, attention_head_size):
    # reshape (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

def bert_self_attention(query, key, value, attention_mask=None, num_attention_heads=1):
        dim_head = query.size()[-1] // num_attention_heads
        all_head_size = dim_head*num_attention_heads

        query_layer = transpose_for_scores(query, num_attention_heads, dim_head)
        key_layer = transpose_for_scores(key, num_attention_heads, dim_head)
        value_layer = transpose_for_scores(value, num_attention_heads, dim_head)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(dim_head)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # The original implementation applies dropout to the attention probabilities here
        # (dropping out entire tokens to attend to); omitted in this standalone test.

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_probs

if __name__ == "__main__":
    query = torch.tensor([[[0.1,0.2],[-0.5,0.7],[-0.5,-0.75],[.123,.456]]])
#    query = torch.tensor([[[0.1,0.2],[-0.5,0.7]]])
    key = value = query

    n_heads = 1
    attn, probs = bert_self_attention(query, key, value, num_attention_heads=n_heads)
    print('bert_self_attention out: ', attn)

    block_size = 1
    # run HAttention1D with block_size = 1 and block_size = 2
    for _ in range(2):
        dim_head = query.size()[-1]//n_heads
        h_attn = HAttention1D(
            dim=query.size()[-1],
            heads=n_heads,
            dim_head=dim_head,
            block_size=block_size
        )

        # bypass the input and output projections so q, k, v can be fed in directly
        h_attn.to_qkv = torch.nn.Identity()
        h_attn.to_out = torch.nn.Identity()

        # concatenate q, k, v along the last dimension, the layout to_qkv's output is chunked from
        qkv = torch.stack([query, key, value], dim=2)
        qkv = torch.flatten(qkv, start_dim=2)

        attn_out = h_attn(qkv)
        print('hattention_1d: (block_size = {})'.format(block_size), attn_out)

        block_size *= 2

This is the output I get

bert_self_attention out:  tensor([[[-0.1807,  0.1959],
         [-0.2096,  0.2772],
         [-0.2656, -0.0568],
         [-0.1725,  0.2442]]])
hattention_1d: (block_size = 1) tensor([[[-0.2000,  0.4500],
         [-0.2000,  0.4500],
         [-0.1885, -0.1470],
         [-0.1885, -0.1470]]])

before it errors out with

assert num_levels >= 0, 'number of levels must be at least greater than 0'

Some of the values are off in absolute magnitude by more than a factor of two.

Looking at the code, this line seems problematic:

num_levels = int(log2(pad_to_len // bsz)) - 2

I believe it should read

num_levels = int(log2(pad_to_len // bsz)) - 1
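
For this four-token example pad_to_len is 4, so the two formulas give the following level counts. This is just a sketch of the arithmetic, with the variable names copied from the library code for illustration:

from math import log2

pad_to_len = 4                # padded sequence length in the example above
for bsz in (1, 2):            # the two block sizes the test loops over
    current  = int(log2(pad_to_len // bsz)) - 2   # formula currently in the code
    proposed = int(log2(pad_to_len // bsz)) - 1   # proposed fix
    print(bsz, current, proposed)

# bsz = 1: current = 0,  proposed = 1
# bsz = 2: current = -1 (trips the assertion above), proposed = 0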

If I make that change, the approximated attention output is much closer to the exact one:

bert_self_attention out:  tensor([[[-0.1807,  0.1959],
         [-0.2096,  0.2772],
         [-0.2656, -0.0568],
         [-0.1725,  0.2442]]])
hattention_1d: (block_size = 1) tensor([[[-0.2590,  0.2020],
         [-0.2590,  0.2020],
         [-0.2590,  0.2020],
         [-0.2590,  0.2020]]])
hattention_1d: (block_size = 2) tensor([[[-0.1808,  0.1972],
         [-0.1980,  0.2314],
         [-0.2438,  0.0910],
         [-0.1719,  0.2413]]])
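
To put a rough number on "much closer", one can compare the block_size = 2 result with the exact output (the tensors below are just copied from the printouts above; this is a sanity check, not part of the test script):

import torch

exact  = torch.tensor([[-0.1807,  0.1959], [-0.2096,  0.2772],
                       [-0.2656, -0.0568], [-0.1725,  0.2442]])
approx = torch.tensor([[-0.1808,  0.1972], [-0.1980,  0.2314],
                       [-0.2438,  0.0910], [-0.1719,  0.2413]])

# maximum absolute deviation from the exact attention output
print((exact - approx).abs().max())   # ~0.15, versus ~0.39 for the block_size = 1 output before the fix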

PR #21 fixes a few (last?) outstanding bugs in the calculation.

Could you also take a look at #22? It may be relevant.