performance issue with the sampling method
dugu9sword opened this issue · 0 comments
Hello, I found that the sampling method (line 99) in nlpaug/nlpaug/model/lang_models/language_models.py performs poorly:
top_n_ids = torch.multinomial(probas, num_samples=n, replacement=False).tolist()
torch.multinomial is rather slow; see pytorch/pytorch#11931.
After replacing it with NumPy's np.random.choice, the call is much faster.
Speed up: 3.221s/call -> 5e-6s/call
top_n_ids = np.random.choice(probas.size(0), size=n, replace=False, p=probas.cpu().numpy()).tolist()
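One caveat with the NumPy version: np.random.choice requires the probabilities to sum to 1 within a tight tolerance, and float32 values copied from a tensor may not pass that check. A minimal sketch of a guarded version follows. The helper name sample_ids_without_replacement and the synthetic probas array are made up for illustration and are not part of nlpaug; with a real tensor you would pass probas.cpu().numpy() instead.

```python
import numpy as np


def sample_ids_without_replacement(probas, n, rng=None):
    """Draw n distinct indices, weighted by probas (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    # Cast to float64 and renormalize so the sum-to-1 check in NumPy passes
    # even when the input came from float32 arithmetic.
    p = np.asarray(probas, dtype=np.float64)
    p = p / p.sum()
    return rng.choice(p.shape[0], size=n, replace=False, p=p).tolist()


# Synthetic stand-in for a softmax output over a 1000-token vocabulary.
logits = np.random.default_rng(0).normal(size=1000).astype(np.float32)
probas = np.exp(logits - logits.max())
probas /= probas.sum()

top_n_ids = sample_ids_without_replacement(probas, n=5, rng=np.random.default_rng(1))
```

The timings quoted above are from the original measurement; this sketch does not try to reproduce them, and the gain may differ depending on vocabulary size and device.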