Cross Attention DCMHA
Closed this issue · 4 comments
Hello! Your results and code style are both outstanding; it's clear you are very proficient with torch and related libraries.
The example code in DCFormer implements a self-attention mechanism, but I want to apply it to cross attention, and I have to say your implementation is so intricate that I haven't fully mastered it yet. What should I modify to get cross attention?
I tried changing only the computation of q, k and v (projecting them from separate inputs), but when I run it, I get an error:
File "transformer_utilities\DynamicMHA.py", line 437, in forward hidden = torch.einsum(eqn1.replace(hidden_sym, ''), inputs,
File "miniconda\envs\cellmemory\lib\site-packages\torch\functional.py", line 377, in einsum return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): subscript S has size 8 for operand 1 which does not broadcast with previously seen size 631
I'm sorry, but I haven't been able to work this out myself. I look forward to your advice!
I made it work for cross attention by modifying the forward function of class CrossHeadProjection in DCFormer.
Warning: I set the model attribute is_training = False so I could focus on this class and make the fix easier.
I have no idea whether that affects anything...
```python
def forward(self, inputs,
            dws: Optional[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]] = None,
            query_vec=None, key_vec=None,
            proj_w: Optional[Tensor] = None,
            fast_infer=True):
    if proj_w is not None:
        ret = torch.einsum('BNTS,BSNM->BMTS', inputs, proj_w)
    else:
        assert dws is not None
        qw1, qw2, kw1, kw2, qdd, kdd = dws
        inputs = inputs.unsqueeze(1)  # BNTS->BGNTS
        # apply sw
        ret = torch.einsum('BGMTS,GMN->BGNTS', inputs, self.w) if self.use_sw else inputs
        if fast_infer:
            inputs_label = ['BGMTS', 'BGMST']
            hidden_sym = 'I'
            hiddens_label = [input_label.replace('M', 'I') for input_label in inputs_label]  # BGITS
            # apply qw and kw  #@todo: huyw fix
            for input_label, hidden_label, sym, (w1, w2) in zip(inputs_label, hiddens_label, ['T', 'S'],
                                                                [(qw1, qw2), (kw1, kw2)]):
                dw_label = f'B{sym}G{hidden_sym}M'  # w1: BTGIM, dw_label: BTGIM
                dynamic_hidden_dim = w1.shape[dw_label.index(hidden_sym)]
                eqn1 = f'{input_label},{dw_label}->{hidden_label}'  # 'BGMTS,BTGMI->BGITS'
                eqn2 = f'{hidden_label},{dw_label}->{input_label}'  # 'BGITS,BTGMI->BGMTS'
                for i in range(dynamic_hidden_dim):
                    # @todo: debug print
                    hidden = torch.einsum(eqn1.replace(hidden_sym, ''), inputs,
                                          w1[..., i, :])  # BGMTS,BTG(I)M->BGTS
                    out = torch.einsum(eqn2.replace(hidden_sym, ''), hidden,
                                       w2[..., i, :])  # 'BG(I)TS,BTG(I)M->BGMTS'
                    ret = ret + out
            # apply qdd and kdd
            for input_label, sym, dd in zip(inputs_label, ['T', 'S'], [qdd, kdd]):
                dd_label = f'B{sym}GM'
                dout = torch.einsum(f'{input_label},{dd_label}->{input_label}', inputs,
                                    dd)  # BGMTS,B(T/S)GM->BGMTS
                ret = ret + dout
        else:
            # apply qw and kw (BTGIN)
            x_inter = torch.einsum('BGNTS, BTGIN->BGTSI', inputs, qw1)
            qw_out = torch.einsum('BGTSI, BTGIN->BGNTS', x_inter, qw2)
            ret = ret + qw_out
            x_inter = torch.einsum('BGNTS, BSGIN->BGTSI', inputs, kw1)
            kw_out = torch.einsum('BGTSI, BSGIN->BGNTS', x_inter, kw2)
            ret = ret + kw_out
            # apply qdd(BTGN) and kdd(BSGN)
            ret = ret + torch.einsum('BGNTS, BTGN->BGNTS', inputs, qdd)
            ret = ret + torch.einsum('BGNTS, BSGN->BGNTS', inputs, kdd)
        ret = ret.squeeze(1)  # BGNTS->BNTS
    return ret
```
I modified the 'else' block:

```python
inputs_label = ['BGMTS', 'BGMST']
hiddens_label = [input_label.replace('M', 'I') for input_label in inputs_label]
```

and

```python
for input_label, sym, dd in zip(inputs_label, ['T', 'S'], [qdd, kdd]):
    dd_label = f'B{sym}GM'
    dout = torch.einsum(f'{input_label},{dd_label}->{input_label}', inputs,
                        dd)  # BGMTS,B(T/S)GM->BGMTS
    ret = ret + dout
```

Then cross attention runs without errors!
I'm not sure whether my modification is correct, thanks!
@WhatMelonGua I'm not sure what you mean by cross attention.
- If it refers to bidirectional attention (BERT-like attention), all you need to do is change the causal attention mask to a bidirectional attention mask (see the mask sketch below). The reason is that DCMHA transforms the attention logits (before softmax) and probabilities (after softmax), of shape BNTS, along the head dimension N, which is orthogonal to the attention mask type.
- If it refers to cross attention in an encoder-decoder architecture (e.g., in a machine translation task), you need to generate the dynamic weights of head composition separately: key-wise weights from the encoder hidden states and query-wise weights from the decoder hidden states, then compose the attention heads.
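For the first case, a minimal sketch of what swapping the mask means (plain PyTorch, hypothetical sizes; the DCMHA head-composition transforms are untouched):

```python
import torch

T = 8  # sequence length (hypothetical)
# causal mask: position t may only attend to positions <= t
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
# bidirectional (BERT-like) mask: every position attends to every position
# (in practice only padding positions would be masked out)
bidirectional_mask = torch.ones(T, T, dtype=torch.bool)

# either mask is applied to the BNTS attention logits the same way, before softmax;
# the head-composition step is independent of which mask is used
logits = torch.randn(1, 4, T, T)  # BNTS
logits = logits.masked_fill(~bidirectional_mask, float('-inf'))
probs = torch.softmax(logits, dim=-1)
```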
@hilbertmeng
Sorry I didn't describe it clearly; yes, it refers to cross attention in an encoder-decoder setup.
Currently, DCMHA takes a single input x and projects it into the three objects Q, K and V.
The cross attention I mean takes a query input and a key input, so that Q attends to K, and V also comes from the key input, as in the small sketch below.
I'll keep the question open to confirm whether my modification is right. Thank you for your suggestions and knowledge!
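For concreteness, this is the plain (non-DCMHA) projection pattern I have in mind; the dims and names here are just for illustration:

```python
import torch
import torch.nn as nn

D, N, d = 512, 8, 64          # model dim, heads, head dim (hypothetical)
wq = nn.Linear(D, N * d)      # query projection, applied to the query-side input
wk = nn.Linear(D, N * d)      # key projection, applied to the key-side input
wv = nn.Linear(D, N * d)      # value projection, also applied to the key-side input

def split_heads(x):           # B, L, N*d -> B, N, L, d
    B, L, _ = x.shape
    return x.view(B, L, N, d).transpose(1, 2)

query_input = torch.randn(2, 10, D)   # decoder side, length T
key_input = torch.randn(2, 20, D)     # encoder side, length S

q = split_heads(wq(query_input))      # B, N, T, d
k = split_heads(wk(key_input))        # B, N, S, d
v = split_heads(wv(key_input))        # B, N, S, d
logits = q @ k.transpose(-1, -2)      # B, N, T, S
```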
@WhatMelonGua I wrote some high-level pseudocode for implementing Cross DCMHA, for reference. Hope it helps.
```python
def cross_attention_dcmha(encoder_hidden, decoder_hidden):  # encoder_hidden: BSD, decoder_hidden: BTD
    # layer normalization
    encoder_hidden = layer_norm(encoder_hidden)
    decoder_hidden = layer_norm(decoder_hidden)

    # qkv projection
    query = project_query(decoder_hidden)  # BTD, DNd -> BNTd
    key = project_key(encoder_hidden)      # BSD, DNd -> BNSd
    value = project_value(encoder_hidden)  # BSD, DNd -> BNSd

    # ignore position embedding
    # ...

    # generate dynamic weights of attention head composition
    pre_kw, post_kw = generate_dw(encoder_hidden)  # BSD -> BSNM
    pre_qw, post_qw = generate_dw(decoder_hidden)  # BTD -> BTNM

    # pre-compose
    attention_logits = query @ key  # BNTd, BNSd -> BNTS
    attention_logits_kw_contrib = pre_kw @ attention_logits  # BSNM, BNTS -> BMTS
    attention_logits_qw_contrib = pre_qw @ attention_logits  # BTNM, BNTS -> BMTS
    attention_logits = attention_logits + attention_logits_kw_contrib + attention_logits_qw_contrib  # BNTS

    # softmax
    attention_probs = softmax(attention_logits, dim=-1)

    # post-compose
    attention_probs_kw_contrib = post_kw @ attention_probs  # BSNM, BNTS -> BMTS
    attention_probs_qw_contrib = post_qw @ attention_probs  # BTNM, BNTS -> BMTS
    attention_probs = attention_probs + attention_probs_kw_contrib + attention_probs_qw_contrib  # BNTS

    mixed_value = attention_probs @ value  # BNTS, BNSd -> BNTd
    out = project_out(mixed_value)         # BNTd, dD -> BTD
    return out
```
You can refer to https://github.com/Caiyun-AI/DCFormer/blob/main/pytorch/dcformer/modeling_dcformer.py#L230-L246 for the generate_dw implementation. For simplicity, qw is short for qw1 @ qw2 + diag_embed(qdd) in the pseudocode, but that form is inefficient. The same holds true for kw.
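To make that shorthand concrete, here is a rough sketch of how the full query-wise composition matrix could be assembled from the low-rank factors plus the diagonal term. It assumes qw1 and qw2 both have shape (B, T, I, N) and qdd has shape (B, T, N), i.e., the group dimension is folded into N; this is for illustration only, not how the repo computes it:

```python
import torch

# Hypothetical sizes: B batch, T query positions, N heads, I rank of the dynamic projection.
B, T, N, I = 2, 5, 8, 2
qw1 = torch.randn(B, T, I, N)   # low-rank factor 1 (per query position)
qw2 = torch.randn(B, T, I, N)   # low-rank factor 2 (per query position)
qdd = torch.randn(B, T, N)      # per-head diagonal (gate) term

# full N x N composition matrix per (batch, query position):
# qw[b, t, n, m] = sum_i qw1[b, t, i, n] * qw2[b, t, i, m] + (n == m) * qdd[b, t, n]
qw = torch.einsum('btin,btim->btnm', qw1, qw2) + torch.diag_embed(qdd)

# applying it to attention logits of shape BNTS, per query position t:
logits = torch.randn(B, N, T, 7)                       # BNTS with S=7 key positions
contrib = torch.einsum('btnm,bnts->bmts', qw, logits)  # BTNM, BNTS -> BMTS
```

Working with qw1, qw2 and qdd directly (as in the repo) avoids materializing the full N x N matrix, which is why the shorthand above is called inefficient.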
Just a kind reminder: if you want to train DCMHA models in pytorch, remember to accelerate training with torch.compile.
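As a minimal illustration (using a stand-in module; in practice it would be the DCFormer model, and torch.compile requires PyTorch 2.0+):

```python
import torch
import torch.nn as nn

# stand-in for the actual model
model = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
compiled_model = torch.compile(model)   # one-line change; the module is used as before

x = torch.randn(4, 16)
out = compiled_model(x)                 # first call triggers compilation, later calls reuse it
out.sum().backward()
```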