r-three/phatgoose

Question about the framework

Opened this issue · 5 comments

I want to understand the overall framework technically. From my understanding and from reading the paper, you are training vectors of shape [input_hidden, 1] (for example, [384, 1] in the case of ViT-Small) for each dataset over a few iterations. Then, at inference time, you stack them into a matrix of size [384, N], where N is the number of datasets or domains. In this setup, sparse MoE is used, and self.w_gate is replaced with this matrix. If I'm wrong in my understanding, please help me correct it.
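
For concreteness, a minimal sketch of that stacking step, with random placeholder tensors standing in for the trained gate vectors (not variables from the repo):

```python
import torch

dim, num_datasets = 384, 8  # ViT-Small hidden size, N datasets/domains
gate_vectors = [torch.randn(dim, 1) for _ in range(num_datasets)]  # stand-ins for the trained [384, 1] gates

# stack column-wise into the routing matrix that takes the place of self.w_gate
w_gate = torch.cat(gate_vectors, dim=1)  # shape [384, N]
assert w_gate.shape == (dim, num_datasets)
```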

Hey! That's correct. Additionally, when training the gate vectors, you need to freeze the entire model, adapters included. At inference, once you stack the gate vectors into a linear layer, make sure to normalize them to account for the fact that the gates were trained independently and can have varying norms. In the paper, we standardized each gate vector to mean 0 and standard deviation 1.
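
A minimal sketch of that normalization, assuming per-gate standardization to mean 0 / std 1; the function name and shapes here are illustrative, not from the repo:

```python
import torch

def standardize(gate: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # each gate vector was trained independently, so rescale it to mean 0 / std 1
    return (gate - gate.mean()) / (gate.std() + eps)

trained_gates = [torch.randn(384) for _ in range(8)]                  # stand-ins for per-dataset gates
w_gate = torch.stack([standardize(g) for g in trained_gates], dim=1)  # [384, 8] routing matrix
```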

Thank you for your response. One thing remains: is there any ablation in the paper on training these vectors? You state that the vectors only need a few iterations (around 100) before they can be used. When I try this in my case, specifically a ReID vision task, I see that even with the same hyperparameters used during expert training, the accuracy increases slowly. How can I tell whether my vectors are correctly learning the routing paths? Also, if you don't mind, could you share your contact information? I'd love to discuss ideas further. Thank you once again!

Hey, for training the gates, there isn't a concrete objective we used to measure whether they are trained properly. In our paper, we trained them for 10% of the expert training steps (100 steps), with all hyperparameters the same as in expert training. The sigmoid gates output 0.5 at initialization, so the initial gate-training loss should be around the same value as the loss at the end of expert training; double-check that the loss doesn't climb much higher during gate training (if it does, try lower learning rates). As long as it stays around that value, training for a fixed number of steps should give reasonable gates to use post-hoc.
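
An illustrative sanity check of that 0.5 starting point, using dummy tensors rather than the repo's code: a zero-initialized gate vector scores every token at exactly 0.5, so the adapted block starts at half strength and the initial gate-training loss should sit near the final expert-training loss.

```python
import torch

dim = 384
gate = torch.zeros(dim)                        # expert_input_gate at initialization
h = torch.randn(4, 129, dim)                   # dummy hidden states [batch, tokens, dim]
scores = torch.sigmoid((h * gate).sum(dim=-1))
print(scores.unique())                         # tensor([0.5000]) at step 0
```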

Thanks a lot for the explanation, Mohammed!

This is a simplified implementation for the ViT case: I use `Block` when training the routing vectors and `PHATGOOSEBlock` at inference. I'm wondering whether I'm missing something in my code.

```python
# Simplified ViT blocks. `Attention`, `DropPath`, `MlpLoRA`, and `SparseDispatcher`
# are assumed to be defined elsewhere in my codebase.
import torch
import torch.nn as nn


class Block(nn.Module):
    """Block used while training the per-dataset routing (gate) vector."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, seq_len=129):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MlpLoRA(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # learnable PHATGOOSE gate vector, one per block, shape [dim]
        self.learn_input_gate = True
        if self.learn_input_gate:
            self.expert_input_gate = nn.Parameter(torch.zeros(dim))

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))

        if self.learn_input_gate:
            xskip = x
            x = self.norm2(x)

            # scalar gate per token: sigmoid of <hidden state, gate vector>
            input_gate_scores = torch.sum(x * self.expert_input_gate, dim=-1)
            input_gate = torch.sigmoid(input_gate_scores)
            x = x * input_gate.unsqueeze(-1)

            x = xskip + self.drop_path(self.mlp(x))
        else:
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PHATGOOSEBlock(nn.Module):
    """Block used at inference: the stacked gate vectors act as a top-1 MoE router."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, seq_len=129):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        # simple MoE with two experts; each entry is that expert's LoRA rank
        self.num_experts = [2, 2]
        self.mlp_experts = nn.ModuleList([MlpLoRA(in_features=dim,
                                                  hidden_features=mlp_hidden_dim,
                                                  act_layer=act_layer,
                                                  drop=drop, lora_rank=rank)
                                          for rank in self.num_experts])

        self.output_size = dim
        self.input_size = dim
        self.hidden_size = seq_len
        self.k = 1  # top-1 routing

        # routing matrix: column i holds the trained gate vector of expert i (kept frozen)
        self.w_gate = nn.Parameter(torch.zeros(self.input_size, len(self.num_experts)), requires_grad=False)
        self.softmax = nn.Softmax(1)

    def top_k_gating(self, x):
        # x: [tokens, dim] flattened hidden states

        # x = x / (torch.norm(x, dim=-1, keepdim=True) + 1e-6)
        # w_gate = self.w_gate / (torch.norm(self.w_gate, dim=-1, keepdim=True) + 1e-6)

        clean_logits = x @ self.w_gate
        logits = self.softmax(clean_logits)

        # keep only the top-k experts (k = 1 here) and renormalize their gates
        top_logits, top_indices = logits.topk(min(self.k, len(self.num_experts)), dim=1)
        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = top_k_logits / (top_k_logits.sum(1, keepdim=True) + 1e-6)  # normalization

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        return gates

    def forward(self, x, register_hook=False):
        # attention first
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))

        # MoE MLP with top-1 routing over the stacked gate vectors
        x_skip = x
        x = self.norm2(x)

        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        gates = self.top_k_gating(x)
        dispatcher = SparseDispatcher(len(self.num_experts), gates)
        expert_inputs = dispatcher.dispatch(x)
        gates = dispatcher.expert_to_gates()
        expert_outputs = [self.mlp_experts[i](expert_inputs[i]) for i in range(len(self.num_experts))]
        x = dispatcher.combine(expert_outputs)
        x = x.view(bsz, length, self.input_size)

        # residual around the MoE MLP, mirroring Block.forward
        x = x_skip + self.drop_path(x)

        return x
```
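
One way to wire the two blocks together, sketched under the assumption that `trained_blocks` holds one trained `Block` per dataset and `phatgoose_block` is the matching `PHATGOOSEBlock` (both names are illustrative): copy each `expert_input_gate` into a column of `w_gate`, standardized to mean 0 / std 1 as suggested earlier in the thread.

```python
import torch

@torch.no_grad()
def load_gates(phatgoose_block, trained_blocks, eps=1e-6):
    # column i of w_gate <- standardized gate vector from dataset i's Block
    for i, blk in enumerate(trained_blocks):
        g = blk.expert_input_gate.detach()
        phatgoose_block.w_gate[:, i] = (g - g.mean()) / (g.std() + eps)
```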