Caiyun-AI/DCFormer

Cross Attention DCMHA

Closed this issue · 4 comments

Hello! Your achievements and code style are both outstanding! It's clear that you are very proficient with torch and related libraries.

In the DCFormer example code you provide, a self-attention mechanism is implemented. I want to apply it to cross attention, but your implementation is so intricate that I still haven't been able to fully understand it. What should I modify to achieve cross attention?

I tried to only change the calculation of q, k and v (computing them separately), but when I run it, I get this error:

File "transformer_utilities\DynamicMHA.py", line 437, in forward hidden = torch.einsum(eqn1.replace(hidden_sym, ''), inputs,

File "miniconda\envs\cellmemory\lib\site-packages\torch\functional.py", line 377, in einsum return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): subscript S has size 8 for operand 1 which does not broadcast with previously seen size 631
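
Here is a minimal, self-contained reproduction of that failure. My guess at the cause (an assumption, not verified against the full model): in cross attention the query length T and the key length S differ, so relabeling the same B×G×M×T×S tensor as 'BGMST', which is harmless in self attention where T == S, no longer matches the key-wise dynamic weights:

import torch

B, G, M, T, S, I = 1, 1, 8, 631, 8, 2         # query length T != key length S
logits = torch.randn(B, G, M, T, S)           # attention logits, physically BGMTS
kw1 = torch.randn(B, S, G, I, M)              # key-wise dynamic weights, BSGIM

# The key-wise pass reads the logits with the label 'BGMST', so einsum sees S = 631
# from the logits but S = 8 from kw1, reproducing the broadcast error above.
torch.einsum('BGMST,BSGM->BGST', logits, kw1[..., 0, :])
# RuntimeError: einsum(): subscript S has size 8 for operand 1 which does not
# broadcast with previously seen size 631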

I'm sorry, I haven't fully grasped this part yet. I look forward to your advice!

I made it support cross attention by modifying the forward function of class CrossHeadProjection in DCFormer.

Warning: I set the model attribute is_training = False so that I could focus on this class and make the fix easier.
I have no idea whether this affects anything else...

    def forward(self, inputs,
                dws: Optional[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]] = None,
                query_vec=None, key_vec=None,
                proj_w: Optional[Tensor] = None,
                fast_infer=True):
        if proj_w is not None:
            ret = torch.einsum('BNTS,BSNM->BMTS', inputs, proj_w)
        else:
            assert dws is not None
            qw1, qw2, kw1, kw2, qdd, kdd = dws
            inputs = inputs.unsqueeze(1)  # BNTS->BGNTS
            # apply sw
            ret = torch.einsum('BGMTS,GMN->BGNTS', inputs, self.w) if self.use_sw else inputs
            if fast_infer:
                inputs_label = ['BGMTS', 'BGMST']
                hidden_sym = 'I'
                hiddens_label = [input_label.replace('M', 'I') for input_label in inputs_label]  # BGITS
                # apply qw and kw #@todo: huyw fix
                for input_label, hidden_label, sym, (w1, w2) in zip(inputs_label, hiddens_label, ['T', 'S'],
                                                                    [(qw1, qw2), (kw1, kw2)]):
                    dw_label = f'B{sym}G{hidden_sym}M'  # w1: BTGIM, dw_label:BTGIM
                    dynamic_hidden_dim = w1.shape[dw_label.index(hidden_sym)]
                    eqn1 = f'{input_label},{dw_label}->{hidden_label}'  # 'BGMTS,BTGMI->BGITS'
                    eqn2 = f'{hidden_label},{dw_label}->{input_label}'  # 'BGITS,BTGMI->BGMTS'
                    for i in range(dynamic_hidden_dim):
                        # @todo: debug print
                        hidden = torch.einsum(eqn1.replace(hidden_sym, ''), inputs,
                                              w1[..., i, :])  # BGMTS,BTG(I)M->BGTS
                        out = torch.einsum(eqn2.replace(hidden_sym, ''), hidden,
                                           w2[..., i, :])  # 'BG(I)TS,BTG(I)M->BGMTS'
                        ret = ret + out
                # apply qdd and kdd
                for input_label, sym, dd in zip(inputs_label, ['T', 'S'], [qdd, kdd]):
                    dd_label = f'B{sym}GM'
                    dout = torch.einsum(f'{input_label},{dd_label}->{input_label}', inputs,
                                        dd)  # BGMTS,B(T/S)GM->BGMTS
                    ret = ret + dout
            else:
                # apply qw and kw (BTGIN)
                x_inter = torch.einsum('BGNTS, BTGIN->BGTSI', inputs, qw1)
                qw_out = torch.einsum('BGTSI, BTGIN->BGNTS', x_inter, qw2)
                ret = ret + qw_out
                x_inter = torch.einsum('BGNTS, BSGIN->BGTSI', inputs, kw1)
                kw_out = torch.einsum('BGTSI, BSGIN->BGNTS', x_inter, kw2)
                ret = ret + kw_out

                # apply qdd(BTGN) and kdd(BSGN)
                ret = ret + torch.einsum('BGNTS, BTGN->BGNTS', inputs, qdd)
                ret = ret + torch.einsum('BGNTS, BSGN->BGNTS', inputs, kdd)
            ret = ret.squeeze(1)  # BGNTS->BNTS
        return ret

I modified the 'else' block: inputs_label = ['BGMTS', 'BGMST'] and hiddens_label = [input_label.replace('M', 'I') for input_label in inputs_label], as well as this loop:

                for input_label, sym, dd in zip(inputs_label, ['T', 'S'], [qdd, kdd]):
                    dd_label = f'B{sym}GM'
                    dout = torch.einsum(f'{input_label},{dd_label}->{input_label}', inputs,
                                        dd)  # BGMTS,B(T/S)GM->BGMTS
                    ret = ret + dout

Then the cross attention works well!

I'm not sure whether I modified it correctly, thanks!

@WhatMelonGua I'm not sure what you mean by cross attention.

  • If it refers to bidirectional attention (BERT-like attention), all you need to do is change the causal attention mask to a bidirectional attention mask (see the sketch after this list). The reason is that DCMHA transforms the attention logits (before softmax) and probabilities (after softmax), which have shape BNTS, along the head dimension N, and this is orthogonal to the attention mask type.
  • If it refers to cross attention in an encoder-decoder architecture (e.g., in machine translation), you need to generate the dynamic weights for head composition, key-wise weights from the encoder hidden states and query-wise weights from the decoder hidden states, and then compose the attention heads.
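
A minimal sketch of the first case, using a plain scaled-dot-product attention call just to illustrate the mask swap (this is not DCFormer code; the head-composition step along N is unaffected either way):

import torch
import torch.nn.functional as F

B, N, T, D = 2, 8, 16, 64                        # batch, heads, sequence length, head dim
q, k, v = (torch.randn(B, N, T, D) for _ in range(3))

# causal (decoder-style) mask: position t may only attend to positions <= t
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
# bidirectional (BERT-like) mask: every position attends to every position
bidirectional_mask = torch.ones(T, T, dtype=torch.bool)

# Only the mask changes; the composition of the BNTS logits/probs along N is untouched.
out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_bidirectional = F.scaled_dot_product_attention(q, k, v, attn_mask=bidirectional_mask)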

@hilbertmeng
Sorry, I didn't describe it clearly. Yes, it refers to cross attention in an encoder-decoder:
Currently, DCMHA accepts a single input x and derives the three objects Q, K and V from it.
The cross attention I mean takes two inputs: the query comes from one, and the key and value come from the other, so Q attends to K and gathers values from the K side.

I'll keep the issue open to confirm whether my modification is correct. Thank you for your suggestions and knowledge!

@WhatMelonGua I wrote some high-level pseudocode for implementing cross DCMHA for reference. Hope it helps.

def cross_attention_dcmha(encoder_hidden, decoder_hidden): # encoder_hidden: BSD, decoder_hidden: BTD
    # layer normalization
    encoder_hidden = layer_norm(encoder_hidden)
    decoder_hidden = layer_norm(decoder_hidden)

    # qkv projection  
    query = project_query(decoder_hidden) # BTD, DNd -> BNTd
    key = project_key(encoder_hidden) # BSD, DNd -> BNSd
    value = project_value(encoder_hidden) # BSD, DNd -> BNSd

    # ignore position embedding 
    # ...

    # generate dynamic weights of attention head composition
    pre_kw, post_kw = generate_dw(encoder_hidden) # BSD -> BSNM
    pre_qw, post_qw = generate_dw(decoder_hidden) # BTD -> BTNM

    # pre-compose
    attention_logits = query @ key # BNTd, BNSd -> BNTS
    attention_logits_kw_contrib = pre_kw @ attention_logits # BSNM, BNTS -> BMTS
    attention_logits_qw_contrib = pre_qw @ attention_logits # BTNM, BNTS -> BMTS
    attention_logits = attention_logits + attention_logits_kw_contrib + attention_logits_qw_contrib # BNTS

    # softmax
    attention_probs = softmax(attention_logits, dim=-1) 

    # post-compose 
    attention_probs_kw_contrib = post_kw @ attention_probs # BSNM, BNTS -> BMTS
    attention_probs_qw_contrib = post_qw @ attention_probs # BTNM, BNTS -> BMTS
    attention_probs = attention_probs + attention_probs_kw_contrib + attention_probs_qw_contrib # BNTS

    mixed_value = attention_probs @ value # BNTS, BNSd -> BNTd
    out = project_out(mixed_value) #  BNTd, NdD -> BTD
    return out

You can refer to https://github.com/Caiyun-AI/DCFormer/blob/main/pytorch/dcformer/modeling_dcformer.py#L230-L246 for the generate_dw implementation. For simplicity, qw in the pseudocode is short for qw1 @ qw2 + diag_embed(qdd), although materializing it that way is inefficient. The same holds true for kw.
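
To make that shorthand concrete, here is a rough sketch with assumed shapes (M = N heads, low rank I); the tensor names mirror the pseudocode, but the shapes are illustrative, not the exact ones returned by generate_dw:

import torch

B, T, S, M, I = 2, 16, 24, 8, 2                  # batch, query len, key len, heads, rank

# low-rank and diagonal pieces per position (shapes assumed for illustration)
qw1, qw2, qdd = torch.randn(B, T, M, I), torch.randn(B, T, I, M), torch.randn(B, T, M)
kw1, kw2, kdd = torch.randn(B, S, M, I), torch.randn(B, S, I, M), torch.randn(B, S, M)

# qw / kw from the pseudocode: low-rank product plus a diagonal term (M == N here)
qw = torch.einsum('BTMI,BTIN->BTMN', qw1, qw2) + torch.diag_embed(qdd)   # query-wise, BTNM
kw = torch.einsum('BSMI,BSIN->BSMN', kw1, kw2) + torch.diag_embed(kdd)   # key-wise, BSNM

# composing the BNTS attention logits (or probabilities) along the head dimension
logits = torch.randn(B, M, T, S)                                         # BNTS
composed = (logits
            + torch.einsum('BTNM,BNTS->BMTS', qw, logits)                # query-wise contribution
            + torch.einsum('BSNM,BNTS->BMTS', kw, logits))               # key-wise contribution
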
Just a kind reminder: if you want to train DCMHA models in PyTorch, keep in mind to accelerate training with torch.compile.
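
For example (a generic one-line usage sketch with a placeholder module, not DCFormer-specific code):

import torch
import torch.nn as nn

model = nn.Linear(512, 512)            # placeholder for a DCMHA/DCFormer model
model = torch.compile(model)           # compilation is triggered on the first forward pass
y = model(torch.randn(8, 512))         # subsequent calls reuse the compiled kernels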