ARBML/klaam

Improving the classification model

zaidalyafeai opened this issue · 1 comments

The classification model needs improvement. The accuracy on the test set is around 62% on the five classes. Here is the model used

class Wav2Vec2ClassificationModel(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        
        self.wav2vec2 = Wav2Vec2Model(config)
        
        self.inner_dim = 128
        self.feature_size = 999
        
        self.tanh = nn.Tanh()
        self.linear1 = nn.Linear(1024, self.inner_dim)
        self.linear2 = nn.Linear(self.inner_dim*self.feature_size, 5)
        self.init_weights()
        
    def freeze_feature_extractor(self):
        self.wav2vec2.feature_extractor._freeze_parameters()

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        x = self.linear1(outputs[0]) 
        x = self.tanh(x)
        x = self.linear2(x.view(-1, self.inner_dim*self.feature_size))
        return {'logits':x}

Model accuracy raised to 83%.