keras-team/tf-keras

Using `layers.MultiHeadAttention` increases the parameter count by a factor of `num_heads`


System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras):
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • TensorFlow installed from (source or binary):
  • TensorFlow version (use command below): 2.12
  • Python version:
  • Bazel version (if compiling from source):
  • GPU model and memory:
  • Exact command to reproduce:

Describe the problem

I am building a model in both TensorFlow and PyTorch, and at some point I need a multi-head attention layer. While checking the implementation details, I noticed that using keras.layers.MultiHeadAttention multiplies the model's parameter count by num_heads, while the parameter count of torch.nn.MultiheadAttention stays the same regardless of num_heads.

Describe the expected behavior

I read an explanation at this link. Quoting:

Correspondence between weights in tf.keras.layers.MultiHeadAttention and nn.MultiheadAttention not so clear, as an example: torch shares weights between heads, while tf keeps them unique. So if you are using weights of pretrained model from pytorch and try to put them in tensorflow model (for whatever reason) it'll certainly take more than five minutes.

It looks like keras.layers.MultiHeadAttention keeps the weights unique per head even when that isn't needed.

Standalone code to reproduce the issue

In TensorFlow, the parameter count grows by a factor of n_head.

In PyTorch:

import torch
import torch.nn as nn

class TorchAttentionModel(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
    def forward(self, x):
        return self.attn(x, x, x)

torch_model = TorchAttentionModel(d_model=32, n_head=8)
dummy_input = torch.randn(10, 16, 32)  # seq len, bsize, feat size
torch_output = torch_model(dummy_input)
torch_params = sum(p.numel() for p in torch_model.parameters() if p.requires_grad)

# --------------------------------
print("PyTorch, number of params:", torch_params)
# Output: PyTorch, number of params: 4224  <-------

In TensorFlow:

import tensorflow as tf
from tensorflow.keras import layers

class TFAttentionModel(tf.keras.Model):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.attn = layers.MultiHeadAttention(num_heads=n_head, key_dim=d_model)
    def call(self, x):
        return self.attn(x, x)


tf_model = TFAttentionModel(d_model=32, n_head=8)
dummy_input = tf.random.normal([16, 10, 32])  # bsize, seq len, feat size
tf_output = tf_model(dummy_input)
tf_params = tf_model.count_params()

# --------------------------------
print("TensorFlow, number of params:", tf_params)
# Output: TensorFlow, number of params: 33568  <-------

Here num_heads is 8. Comparing the TF and torch parameter counts:


tf_params / torch_params
7.946969696969697
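The two counts above (4224 and 33568) can be reproduced by hand. This is a minimal sketch, assuming both layers use their default settings: biases enabled; in torch, kdim = vdim = embed_dim; in keras, value_dim defaults to key_dim and the output is projected back to the query's feature size. The helper names torch_mha_params and keras_mha_params are hypothetical:

```python
def torch_mha_params(embed_dim: int, num_heads: int) -> int:
    # Assumes torch defaults (bias=True, kdim = vdim = embed_dim).
    # num_heads only decides how the projections are split, not their size.
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    in_proj = 3 * (embed_dim * embed_dim + embed_dim)  # packed Q, K, V weights + biases
    out_proj = embed_dim * embed_dim + embed_dim       # output projection
    return in_proj + out_proj


def keras_mha_params(input_dim: int, num_heads: int, key_dim: int) -> int:
    # Assumes keras defaults (use_bias=True, value_dim = key_dim,
    # output projected back to input_dim).
    proj = input_dim * num_heads * key_dim + num_heads * key_dim  # one of Q, K, V
    out = num_heads * key_dim * input_dim + input_dim             # output projection
    return 3 * proj + out


torch_params = torch_mha_params(embed_dim=32, num_heads=8)
tf_params = keras_mha_params(input_dim=32, num_heads=8, key_dim=32)
print(torch_params, tf_params, tf_params / torch_params)
```

The keras count scales with num_heads * key_dim, while the torch count depends on embed_dim alone.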

@sachinprasadhs,
I was able to reproduce the issue on TensorFlow v2.1, v2.12 and tf-nightly. Please find the gist here.

@tilakrayal Thanks for checking. I've found the cause of the mismatch.

The docs for keras.layers.MultiHeadAttention(num_heads, key_dim, ...) say:

num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.

The docs for torch.nn.MultiheadAttention(embed_dim, num_heads, ...) say:

embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads.
           Note that embed_dim will be split across num_heads
           (i.e. each head will have dimension embed_dim // num_heads).
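The torch docs also clarify the quoted claim earlier in the thread: torch does not reuse one set of weights for every head; it slices a single embed_dim x embed_dim projection into num_heads chunks. This minimal numpy sketch (random weights, biases omitted, a single batch item; all variable names are illustrative) shows that this slicing gives the same per-head queries as a keras-style kernel of shape (input_dim, num_heads, key_dim) with key_dim = embed_dim // num_heads:

```python
import numpy as np

embed_dim, num_heads = 32, 8
head_dim = embed_dim // num_heads  # 4, per the torch docs

rng = np.random.default_rng(0)
x = rng.standard_normal((10, embed_dim))  # (seq_len, embed_dim), one batch item

# torch-style query projection: one (embed_dim, embed_dim) weight stored as (out, in),
# whose output is then split into num_heads chunks of head_dim features
w_torch = rng.standard_normal((embed_dim, embed_dim))
q_torch = (x @ w_torch.T).reshape(10, num_heads, head_dim)

# keras-style query kernel: (input_dim, num_heads, key_dim) with key_dim = head_dim
w_keras = w_torch.T.reshape(embed_dim, num_heads, head_dim)
q_keras = np.einsum("sd,dhk->shk", x, w_keras)

print(np.allclose(q_torch, q_keras))
```

Both layouts hold the same embed_dim * embed_dim query weights; each head still gets its own slice.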

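So keras's key_dim is the size of a single head, while torch's embed_dim is the total size split across all heads. To match torch, key_dim should be d_model // n_head, i.e. TFAttentionModel(d_model=32 // 8, n_head=8) instead of TFAttentionModel(d_model=32, n_head=8).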
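A minimal sketch of the corrected keras setup, using the layer directly instead of the TFAttentionModel wrapper and assuming the same shapes as the reproduction above:

```python
import tensorflow as tf
from tensorflow.keras import layers

d_model, n_head = 32, 8

# Match torch: each head gets d_model // n_head features instead of d_model.
attn = layers.MultiHeadAttention(num_heads=n_head, key_dim=d_model // n_head)

x = tf.random.normal([16, 10, d_model])  # (batch, seq_len, features)
y = attn(x, x)                           # self-attention; calling the layer builds its weights
print(y.shape, attn.count_params())
```

Since value_dim defaults to key_dim and the output is projected back to the query's last dimension, the layer still returns 32 features per position, and its parameter count drops to 4224, the same as the torch model.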