glassroom/heinsen_routing

Heinsen routing for NMT task - unexpected results

aimanmutasem opened this issue · 1 comment

Dear Mr. Heinsen,

Thank you so much for the magnificent article and its simple implementation. I have enjoyed reading the article several times.

I'm trying to apply your Heinsen routing module to an NMT task with the Transformer-base architecture.

The idea is to use the output of each attention head as an input capsule on both the encoder and decoder sides, then concatenate the output capsules and forward them to the next layer, the feed-forward network (FFN).

The problem is that the results are poor and the training time has increased by about 500% (I think because the routing process runs on the CPU).

I don't know where I went wrong, and I suspect there is a mistake in my implementation.

I hope you can help me find the right way to apply your routing module in my project.

Sincerely,
Aiman

Please take a look at the implementation code and the attached diagram.

```python
import torch
import torch.nn as nn

from heinsen_routing import Routing  # Routing as used below, assumed to come from this repository


class MultiHeadAttentionLayer(nn.Module):

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        self.device = device
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask=None):

        batch_size = query.shape[0]
        query_len = query.shape[1]
        # query = [batch size, query len, hid dim]
        # key   = [batch size, key len, hid dim]
        # value = [batch size, value len, hid dim]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Q = [batch size, query len, hid dim]
        # K = [batch size, key len, hid dim]
        # V = [batch size, value len, hid dim]

        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # Q = [batch size, n heads, query len, head dim]
        # K = [batch size, n heads, key len, head dim]
        # V = [batch size, n heads, value len, head dim]

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        # energy = [batch size, n heads, query len, key len]

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim=-1)

        # attention = [batch size, n heads, query len, key len]

        x = torch.matmul(self.dropout(attention), V)

        # x = [batch size, n heads, query len, head dim]

        # --------------------------- Capsulation layer ---------------------------

        mu = x.permute(0, 2, 1, 3).contiguous().to(torch.float)

        # mu = [batch size, query len, n heads, head dim]

        a = mask.to(torch.float)

        # a = source mask shape: [batch size, 1, 1, query len]
        # a = target mask shape: [batch size, 1, query len, query len]

        # ------- Adjust the target mask to match the source mask shape -----------

        if a.shape[2] == a.shape[3]:
            a = torch.mean(a, 2, True)
            a = a.unsqueeze(2)

        # a = [batch size, query len] - source and target now have the same mask shape
        a = a.view(a.shape[0], -1)

        # ----------- End of mask shape adjustment ---------------------------------

        # d_cov = n heads
        # d_inp = head_dim = (embedding dimension) / (number of heads)
        # d_out = head_dim = (embedding dimension) / (number of heads)
        # n_out = query len

        m = Routing(d_cov=8, d_inp=32, d_out=32, n_out=29)

        a_out, mu_out, sig2_out = m(a.cpu(), mu.cpu())

        mu_out = mu_out.permute(0, 2, 1, 3).contiguous()

        # mu_out = [batch size, query len, n heads, head dim]

        mu_out = mu_out.view(batch_size, -1, self.hid_dim)

        # mu_out = [batch size, query len, hid dim]

        # ---------------------------------------------------------------------------

        x = self.fc_o(mu_out.to(self.device))

        return x, attention
```

[Attached diagram: Decoder-Page-1]

Thanks for the kind words.

Your code is instantiating a new PyTorch nn.Module with randomly initialized weights on each forward pass. As with every other PyTorch nn.Module, you must instantiate the routing module at initialization (e.g., `self.my_routing_module = Routing(...)`) and then use it in the forward pass (e.g., `a_out, mu_out, sig2_out = self.my_routing_module(...)`) so it can learn during training.
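
For illustration, here is a minimal sketch of that pattern, assuming the `Routing` call signature used in your snippet (the class name `RoutedHeads` and the default parameter values are placeholders copied from your code, not a prescription):

```python
import torch.nn as nn

from heinsen_routing import Routing  # assuming Routing comes from this repository


class RoutedHeads(nn.Module):
    """Toy layer: the routing module is created once, in __init__, and reused."""

    def __init__(self, n_heads=8, head_dim=32, n_out=29):
        super().__init__()
        # Registered as a submodule, so its weights are learned along with the
        # rest of the model and follow it when you call .to(device) or .cuda().
        self.routing = Routing(d_cov=n_heads, d_inp=head_dim, d_out=head_dim, n_out=n_out)

    def forward(self, a, mu):
        # a:  [batch size, n inp]                      input capsule activations
        # mu: [batch size, n inp, n heads, head dim]   input capsule poses
        a_out, mu_out, sig2_out = self.routing(a, mu)  # same module on every call
        return a_out, mu_out, sig2_out
```

Keeping the routing module registered this way also lets it run on the same device as the rest of the layer, so you should not need the `.cpu()` calls that are likely contributing to the slowdown you mention.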

If you have more questions about the basic usage of PyTorch, please ask them in a PyTorch forum, not here.

Good luck!