bentrevett/pytorch-sentiment-analysis

6 - Transformers for Sentiment Analysis: Model class typo

ChrisProgramming2018 opened this issue · 1 comment

I think there is a typo in the `forward` function of the model class:
```python
def forward(self, text):

    #text = [batch size, sent len]

    with torch.no_grad():
        embedded = bert(text)[0]

    #embedded = [batch size, sent len, emb dim]

    _, hidden = self.rnn(embedded)
```
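`bert(text)[0]` needs to be replaced with `self.bert(text)[0]`, so that the method uses the BERT model stored on the instance rather than whatever global variable named `bert` happens to exist.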
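A minimal, framework-free sketch of why the one-word change matters. Inside a method, Python resolves a bare name like `bert` through local, enclosing, global, and built-in scopes; instance attributes are never searched. The unfixed code therefore only appears to work when a module-level `bert` is defined, and it silently uses that object instead of the encoder passed to the constructor. `FakeEncoder`, `BuggyModel`, and `FixedModel` are hypothetical names; the real tutorial code uses a PyTorch `nn.Module` and a pretrained BERT, which are left out here so the example runs with the standard library alone.

```python
class FakeEncoder:
    """Hypothetical stand-in for a BERT model: returns a tuple whose first
    element plays the role of the embedded sequence."""

    def __init__(self, name):
        self.name = name

    def __call__(self, text):
        return ([f"{self.name}:{tok}" for tok in text],)


class BuggyModel:
    def __init__(self, bert):
        self.bert = bert

    def forward(self, text):
        # Bug: `bert` is looked up in the module's global scope,
        # not on the instance, so self.bert is ignored.
        return bert(text)[0]


class FixedModel:
    def __init__(self, bert):
        self.bert = bert

    def forward(self, text):
        # Fix: use the encoder that was passed to the constructor.
        return self.bert(text)[0]


bert = FakeEncoder("global")             # a global that happens to share the name
local_encoder = FakeEncoder("passed-in")

buggy = BuggyModel(local_encoder)
fixed = FixedModel(local_encoder)

buggy_output = buggy.forward(["hi"])
print(buggy_output)                      # ['global:hi'] (wrong encoder, no error)
print(fixed.forward(["hi"]))             # ['passed-in:hi']

# Once the global disappears, the buggy version fails outright.
del bert
error = None
try:
    buggy.forward(["hi"])
except NameError as err:
    error = err
    print("Without a global `bert`, the buggy model fails:", err)
```

The fixed class behaves the same whether or not a global `bert` exists, which is the property the one-line change restores.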

You're correct; I'll change that now.