berlino/gated_linear_attention

Advice for small-sized GLA

Closed this issue · 3 comments

Thank you for this amazing work,

I'm trying to use your work as a drop-in replacement for other SSMs such as Mamba and RWKV. Note that I train significantly smaller models (from 20M to 60M parameters), on tasks not related to natural language generation. So far I have encouraging results and I believe GLA should be competitive, but I have not yet matched RWKV/Mamba, despite promising speed/VRAM usage.

I have multiple questions about integrating GLA correctly:

  1. What is your advice on parameter choices when scaling a GLA Transformer down to ~20-60M parameters, in terms of layers/dimension/heads?
  2. Can you confirm that my interpretation of a GLABlock is correct?
import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    # SwiGLU feed-forward with an ~8/3 expansion, gate and value fused in one input projection.
    def __init__(self, d_model):
        super().__init__()
        hidden = d_model * 4 // 3                   # size of each half after the chunk
        self.p_in = nn.Linear(d_model, 2 * hidden)  # fused gate and value projections
        self.p_out = nn.Linear(hidden, d_model)
    def forward(self, x):
        gate, x = self.p_in(x).chunk(2, dim=-1)
        return self.p_out(nn.functional.silu(gate) * x)

class GLABlock(nn.Module):
    def __init__(self, d_model, heads):
        super().__init__()
        self.att = GatedLinearAttention(d_model, heads)
        self.ffn = SwiGLU(d_model)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        y = self.att(self.ln(x)) + x
        y = self.ffn(y)
        return y

self.ffn(y) should be self.ffn(self.ln2(y)).

What is the performance gap? For a smaller model, D_model is small, and therefore D_head is small. We would suggest using a small number of heads, e.g., 1, to keep D_head >= 64.
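For readers following along, here is a sketch of what the corrected block would look like as a pre-norm block with two LayerNorms. The residual around the FFN is an assumption based on the standard Transformer layout (the comment above only mentions the missing second LayerNorm), the class name GLABlockFixed is illustrative, and the GatedLinearAttention(d_model, heads) signature is taken from the snippet in the question:

import torch
import torch.nn as nn

class GLABlockFixed(nn.Module):
    # Pre-norm block: LN -> attention -> residual, then LN -> FFN -> residual.
    def __init__(self, d_model, heads):
        super().__init__()
        self.att = GatedLinearAttention(d_model, heads)  # signature as used in the question
        self.ffn = SwiGLU(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)  # the second LayerNorm pointed out above

    def forward(self, x):
        x = self.att(self.ln1(x)) + x
        x = self.ffn(self.ln2(x)) + x  # assumed residual around the FFN, as in standard Transformers
        return x

# At the 20-60M scale, d_model is small, so pick a small number of heads (e.g. 1 or 2)
# to keep the per-head dimension >= 64, as suggested above.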

self.ffn(y) should be self.ffn(self.ln2(y)).

I had missed that. Interestingly, this alone seems to make it competitive. I will keep investigating now.

[Screenshot from 2023-12-17 19:35:40]

Thank you.

Seems that we missed something in the paper. I checked our code implementation and it has two LayerNorms, like Transformers. I am interested in the exact numbers on the performance gap. Our smallest experimental scale is 350M. My general sense is that for smaller models, token mixing is more important, so you might want to try out the parameter allocation used in RetNet, i.e., D_k = D_model, D_v = 2*D_model. Our finding is that for large-scale models, allocating more parameters to the FFNs is more important.
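
In case it helps when scaling down, here is a minimal sketch of what that RetNet-style allocation means for the projection shapes. The helper name and layout are hypothetical; the actual GLA layer additionally has gating projections and a different internal wiring:

import torch.nn as nn

def retnet_style_projections(d_model, num_heads=1):
    # Sketch of the allocation suggested above: D_k = D_model, D_v = 2 * D_model.
    # Hypothetical helper; not the repository's actual module.
    d_k = d_model            # query/key dimension
    d_v = 2 * d_model        # value dimension
    assert d_k % num_heads == 0 and d_v % num_heads == 0, "head dims must divide evenly"
    q_proj = nn.Linear(d_model, d_k)
    k_proj = nn.Linear(d_model, d_k)
    v_proj = nn.Linear(d_model, d_v)
    o_proj = nn.Linear(d_v, d_model)  # map concatenated value heads back to d_model
    return q_proj, k_proj, v_proj, o_proj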