Don't you think this is more legible?
ZIZUN commented
The snippet below is `SpatialGatingUnit` with the `nn.Conv1d` spatial projection replaced by an explicit weight matrix and bias, so the gating step reads directly as `WZ + b`:

```python
import torch
from torch import nn

# helper used in the repository
def exists(val):
    return val is not None

class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        super().__init__()
        dim_out = dim // 2
        self.causal = causal
        self.norm = nn.LayerNorm(dim_out)

        # explicit spatial projection in place of nn.Conv1d(dim_seq, dim_seq, 1);
        # weight starts at zero and bias at one, so at init the gate is ~1 and
        # the unit passes `res` straight through (with the default Identity act)
        self.dim_seq = dim_seq
        self.w_ = nn.Parameter(torch.zeros(dim_seq, dim_seq), requires_grad = True)
        self.b_ = nn.Parameter(torch.ones(dim_seq), requires_grad = True)

        self.act = act
        # init_eps is unused here; the small uniform init of the Conv1d weight
        # is replaced by the zero init of self.w_ above

    def forward(self, x, gate_res = None):  # x -> (bsz, len, dim)
        device, n = x.device, x.shape[1]

        res, gate = x.chunk(2, dim = -1)
        gate = self.norm(gate)

        weight, bias = self.w_, self.b_  # weight -> (len, len), bias -> (len,)

        if self.causal:
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape, device = device).triu_(1).bool()
            weight = weight.masked_fill(mask, 0.)  # zero out future positions

        gate = torch.matmul(weight, gate) + bias[None, :n, None]  # WZ + b
        # equivalent to the original: gate = F.conv1d(gate, weight[..., None], bias)

        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res
```
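
For anyone trying this out, here is a minimal smoke test (the sizes below are arbitrary examples; in the non-causal case the input sequence length must equal `dim_seq`, while `causal = True` also accepts shorter sequences):

```python
import torch

sgu = SpatialGatingUnit(dim = 512, dim_seq = 128, causal = True)

x = torch.randn(2, 128, 512)   # (bsz, len, dim); dim is chunked into two halves
out = sgu(x)

print(out.shape)               # torch.Size([2, 128, 256]) == (bsz, len, dim // 2)
```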