lucidrains/g-mlp-pytorch

Don't you think this is more legible?

ZIZUN opened this issue · 0 comments

ZIZUN commented

```python
class SpatialGatingUnit(nn.Module):
    """Spatial Gating Unit (gMLP) using an explicit (dim_seq x dim_seq) mixing matrix.

    The input's last dimension is split into two halves, ``res`` and ``gate``.
    ``gate`` is layer-normalized and mixed along the sequence axis with the
    learned matrix ``w_`` plus a per-position bias ``b_``. It is then optionally
    offset by ``gate_res``, passed through ``act``, and multiplied element-wise
    with ``res``.

    Args:
        dim: size of the input's last dimension (presumably even, since it is
            split in two halves and the output has ``dim // 2`` features).
        dim_seq: side of the square mixing matrix, i.e. the (maximum)
            sequence length.
        causal: if True, output position i only mixes input positions j <= i,
            and sequences shorter than ``dim_seq`` are accepted.
        act: activation applied to the mixed gate.
        init_eps: mixing weights are initialized uniformly in
            ``[-init_eps / dim_seq, init_eps / dim_seq]``; the bias starts at 1.
    """

    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        super().__init__()
        dim_out = dim // 2
        self.causal = causal
        self.dim_seq = dim_seq

        self.norm = nn.LayerNorm(dim_out)
        self.act = act

        # Plain-matrix form of a kernel-size-1 Conv1d over the sequence axis,
        # so the spatial mixing below is a single matmul.
        self.w_ = nn.Parameter(torch.empty(dim_seq, dim_seq))
        self.b_ = nn.Parameter(torch.empty(dim_seq))

        # Near-zero weights and a unit bias make the mixed gate start close to 1.
        # With the default identity activation and no gate_res, the unit
        # therefore initially returns approximately ``res``.
        init_eps /= dim_seq
        nn.init.uniform_(self.w_, -init_eps, init_eps)
        nn.init.constant_(self.b_, 1.)

    def forward(self, x, gate_res = None):
        """Apply spatial gating to ``x``.

        Args:
            x: tensor laid out as (batch, seq_len, dim). seq_len must equal
                ``dim_seq`` when non-causal, and be at most ``dim_seq`` when
                causal.
            gate_res: optional tensor added to the mixed gate before ``act``.
                It must broadcast against (batch, seq_len, dim // 2).

        Returns:
            Tensor of shape (batch, seq_len, dim // 2).
        """
        n = x.shape[1]

        res, gate = x.chunk(2, dim = -1)
        gate = self.norm(gate)

        weight, bias = self.w_, self.b_

        if self.causal:
            # Crop to the actual sequence length, then zero the strict upper
            # triangle so output position i never sees positions j > i.
            # The mask is 2-D, matching the 2-D weight, so no broadcasting.
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones((n, n), device = x.device).triu_(1).bool()
            weight = weight.masked_fill(mask, 0.)

        # WZ + b: (n, n) @ (batch, n, d) -> (batch, n, d); one bias per position.
        gate = torch.matmul(weight, gate) + bias[None, :, None]

        if gate_res is not None:
            gate = gate + gate_res

        return self.act(gate) * res

```