SkipGramNegativeSampling
Yet another, but quite general, negative sampling loss implemented in PyTorch. Corpus reference: dl4j (deeplearning4j).
It can be used with any embedding scheme that maps word indices to vectors! Pretty fast, I bet. Typical usage:
from torch.optim import Adam

# Word2Vec and SGNS live in this repo; the exact module path is assumed here.
from model import Word2Vec, SGNS

V = len(vocab)  # vocabulary size
word2vec = Word2Vec(V=V)
sgns = SGNS(V=V, embedding=word2vec, batch_size=128, window_size=4, n_negatives=5)
optimizer = Adam(sgns.parameters())

# dataloader yields batches of (center word index, context word indices)
for batch, (iword, owords) in enumerate(dataloader):
    loss = sgns(iword, owords)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
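For reference, the objective behind a loss like this scores each true context word positively against the center word and each sampled noise word negatively. Below is a minimal sketch of that computation, assuming batched embedding lookups have already happened; the function name sgns_loss and the tensor shapes are illustrative, not this repo's exact internals.

import torch
import torch.nn.functional as F

def sgns_loss(ivectors, ovectors, nvectors):
    # ivectors: (B, D) center-word vectors
    # ovectors: (B, C, D) true context-word vectors
    # nvectors: (B, C * k, D) sampled negative-word vectors
    ivectors = ivectors.unsqueeze(2)                                # (B, D, 1)
    # positives: reward high dot products with real context words
    oloss = F.logsigmoid(torch.bmm(ovectors, ivectors).squeeze(2))  # (B, C)
    # negatives: reward low dot products with sampled noise words
    nloss = F.logsigmoid(-torch.bmm(nvectors, ivectors).squeeze(2)) # (B, C * k)
    return -(oloss.mean(1) + nloss.mean(1)).mean()

# smoke test with random vectors (B=batch, C=window, D=dim, k=negatives)
B, C, D, k = 128, 4, 300, 5
loss = sgns_loss(torch.randn(B, D), torch.randn(B, C, D), torch.randn(B, C * k, D))

Because the loss only ever touches the looked-up vectors, any embedding module that can produce center and context vectors plugs in unchanged, which is what makes the scheme embedding-agnostic.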