Using `layers.MultiHeadAttention` increases the parameter count `num_heads` times
System information.
- TensorFlow version: 2.12
Describe the problem
I am building a model in both TensorFlow and PyTorch, and at some point I need a MultiHeadAttention layer. While checking the implementation details, I noticed that using keras.layers.MultiHeadAttention increases the model's parameter count roughly num_heads times, while the parameter count of torch.nn.MultiheadAttention stays the same regardless of num_heads.
Describe the expected behavior
I read an explanation at this link. Quoting:
Correspondence between weights in tf.keras.layers.MultiHeadAttention and nn.MultiheadAttention not so clear, as an example: torch shares weights between heads, while tf keeps them unique. So if you are using weights of pretrained model from pytorch and try to put them in tensorflow model (for whatever reason) it'll certainly take more than five minutes.
It looks like Keras MHA unnecessarily keeps the per-head weights unique, even when that's not needed.
Standalone code to reproduce the issue
The parameter count in tf increases roughly n_head times.
In torch:
import torch
import torch.nn as nn

class TorchAttentionModel(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)

    def forward(self, x):
        return self.attn(x, x, x)

torch_model = TorchAttentionModel(d_model=32, n_head=8)
dummy_input = torch.randn(10, 16, 32)  # seq len, bsize, feat size
torch_output = torch_model(dummy_input)
torch_params = sum(p.numel() for p in torch_model.parameters() if p.requires_grad)
# --------------------------------
print("PyTorch, number of params:", torch_params)
# Output: PyTorch, number of params: 4224  <-------
In tf:
import tensorflow as tf
from tensorflow.keras import layers

class TFAttentionModel(tf.keras.Model):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.attn = layers.MultiHeadAttention(num_heads=n_head, key_dim=d_model)

    def call(self, x):
        return self.attn(x, x)

tf_model = TFAttentionModel(d_model=32, n_head=8)
dummy_input = tf.random.normal([16, 10, 32])  # bsize, seq len, feat size
tf_output = tf_model(dummy_input)
tf_params = tf_model.count_params()
# --------------------------------
print("TensorFlow, number of params:", tf_params)
# Output: TensorFlow, number of params: 33568  <-------
Here num_heads is 8. Comparing the tf and torch parameter counts:
tf_params / torch_params
7.946969696969697
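Both counts can be reproduced by hand. The sketch below is a back-of-the-envelope derivation, not code from either library; the helper names `torch_mha_params` and `keras_mha_params` are hypothetical, and it assumes default settings (biases enabled, Keras `value_dim` equal to `key_dim`, Keras input and output width equal to `d_model`).

```python
def torch_mha_params(embed_dim: int) -> int:
    """Hypothetical helper: trainable params of torch.nn.MultiheadAttention,
    assuming bias=True and kdim == vdim == embed_dim. No num_heads argument:
    the count does not depend on it."""
    in_proj = 3 * embed_dim * embed_dim + 3 * embed_dim  # packed Q/K/V weight + bias
    out_proj = embed_dim * embed_dim + embed_dim         # output Linear weight + bias
    return in_proj + out_proj


def keras_mha_params(input_dim: int, num_heads: int, key_dim: int) -> int:
    """Hypothetical helper: trainable params of keras.layers.MultiHeadAttention
    for self-attention, assuming use_bias=True, value_dim == key_dim and the
    default output_shape (== input_dim)."""
    per_proj = input_dim * num_heads * key_dim + num_heads * key_dim  # one of Q, K, V
    output = num_heads * key_dim * input_dim + input_dim              # output projection
    return 3 * per_proj + output


torch_params = torch_mha_params(32)
tf_params = keras_mha_params(32, num_heads=8, key_dim=32)
print(torch_params, tf_params, tf_params / torch_params)
# -> 4224 33568 7.946969696969697
```

With key_dim fixed at 32, the Keras count works out to 4192 * num_heads + 32, so it grows almost linearly with the number of heads. The ratio lands a little under 8 because the output bias (32 parameters) does not scale with num_heads.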
@sachinprasadhs,
I was able to reproduce the issue on tensorflow v2.1, v2.12 and tf-nightly. Kindly find the gist of it here.
@tilakrayal Thanks for the checks. I've found the cause of the mismatch.
The docs for keras.layers.MultiHeadAttention(num_heads, key_dim, ...) say:
num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.
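To see why key_dim acts as a per-head size, it helps to list the projection kernel shapes. This is a sketch assuming the EinsumDense layout Keras uses for this layer (query/key/value kernels of shape (input_dim, num_heads, key_dim), output kernel of shape (num_heads, value_dim, input_dim)); `keras_mha_kernel_shapes` is a hypothetical helper, not a Keras API.

```python
def keras_mha_kernel_shapes(input_dim, num_heads, key_dim, value_dim=None):
    """Hypothetical helper: kernel shapes of the four projections in
    keras.layers.MultiHeadAttention, assuming its EinsumDense layout.
    value_dim defaults to key_dim, as in Keras."""
    value_dim = key_dim if value_dim is None else value_dim
    return {
        "query": (input_dim, num_heads, key_dim),
        "key": (input_dim, num_heads, key_dim),
        "value": (input_dim, num_heads, value_dim),
        "output": (num_heads, value_dim, input_dim),
    }


# The configuration from the issue: d_model=32 passed as key_dim, 8 heads.
for name, shape in keras_mha_kernel_shapes(32, 8, 32).items():
    print(name, shape)
```

Each head gets its own key_dim-wide slice, so the combined projection width is num_heads * key_dim (256 here), while torch's in-projection stays embed_dim (32) wide.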
The docs for torch.nn.MultiheadAttention(embed_dim, num_heads, ...) say:
embed_dim – Total dimension of the model.
num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
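A minimal sketch of that split in plain Python, assuming a single already-projected vector; `split_heads` is a hypothetical helper, and torch itself does this with tensor reshapes rather than list slicing. Because the heads only partition the embed_dim-wide projection, adding heads does not add weights.

```python
def split_heads(projected, num_heads):
    """Hypothetical helper: split one projected vector of length embed_dim
    into num_heads contiguous chunks of length embed_dim // num_heads."""
    embed_dim = len(projected)
    if embed_dim % num_heads != 0:
        # torch also rejects embed_dim values that num_heads does not divide
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    return [projected[i * head_dim:(i + 1) * head_dim] for i in range(num_heads)]


heads = split_heads(list(range(32)), num_heads=8)
print(len(heads), len(heads[0]))  # -> 8 4
```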
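So in Keras, key_dim is the size of each head, whereas torch's embed_dim is the total size that gets split across heads. To match torch, instead of TFAttentionModel(d_model=32, n_head=8), it should be TFAttentionModel(d_model=32 // 8, n_head=8), i.e. key_dim = d_model // num_heads.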
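A quick arithmetic check of this fix, using hand-derived parameter counts (hypothetical helpers, biases enabled, value_dim equal to key_dim): with key_dim = d_model // num_heads the total projection width is back to d_model, so the Keras count should match torch for every head count that divides d_model.

```python
def torch_mha_params(embed_dim):
    # Hypothetical helper: packed Q/K/V in-projection (weight + bias)
    # plus the output Linear (weight + bias); independent of num_heads.
    return 4 * embed_dim * embed_dim + 4 * embed_dim


def keras_mha_params(input_dim, num_heads, key_dim):
    # Hypothetical helper: Q, K, V projections (value_dim == key_dim)
    # plus the output projection, all with bias.
    qkv = 3 * (input_dim * num_heads * key_dim + num_heads * key_dim)
    out = num_heads * key_dim * input_dim + input_dim
    return qkv + out


d_model = 32
for n_head in (1, 2, 4, 8):
    tf_count = keras_mha_params(d_model, n_head, key_dim=d_model // n_head)
    print(n_head, tf_count, torch_mha_params(d_model))
```

In the model itself this corresponds to layers.MultiHeadAttention(num_heads=n_head, key_dim=d_model // n_head).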