makcedward/nlpaug

performance issue with the sampling method

dugu9sword opened this issue

Hello, I found that the sampling method (line 99) in nlpaug/nlpaug/model/lang_models/language_models.py performs poorly:

top_n_ids = torch.multinomial(probas, num_samples=n, replacement=False).tolist()
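
For reference, here is a rough NumPy-only sketch of what the line above computes: n distinct indices drawn one at a time, each with probability proportional to the weights not yet picked. It models the semantics of replacement=False as I understand them, not PyTorch's actual kernel. It is deliberately simple (one weighted draw per index), and the helper name sequential_sample_without_replacement is hypothetical.

```python
import numpy as np


def sequential_sample_without_replacement(weights, n, rng=None):
    """Hypothetical helper: draw n distinct indices, each draw proportional
    to the weights not picked so far. Weights need not sum to 1."""
    rng = np.random.default_rng() if rng is None else rng
    w = np.asarray(weights, dtype=np.float64).copy()  # copy: caller's array is left untouched
    if n > np.count_nonzero(w):
        raise ValueError("n exceeds the number of non-zero weights")
    picked = []
    for _ in range(n):
        idx = rng.choice(len(w), p=w / w.sum())  # one weighted draw over remaining mass
        picked.append(int(idx))
        w[idx] = 0.0  # exclude the drawn index from later draws
    return picked


if __name__ == "__main__":
    probs = np.array([0.5, 0.3, 0.15, 0.05, 0.0])
    print(sequential_sample_without_replacement(probs, 3, np.random.default_rng(0)))
```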

torch.multinomial is known to be slow (see pytorch/pytorch#11931).

Switching to NumPy makes the call much faster, going from 3.221 s/call to 5e-6 s/call:

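p = probas.detach().cpu().double().numpy()  # float64 copy on the CPU
top_n_ids = np.random.choice(len(p), n, replace=False, p=p / p.sum()).tolist()  # np.random.choice requires p to sum to 1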
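
If np.random.choice ever becomes a bottleneck too, a different technique, the Gumbel-top-k trick (add independent Gumbel noise to each log-probability and keep the n largest keys), samples from the same distribution as sequential sampling without replacement in one vectorized pass. This is only a sketch of that alternative, not what nlpaug uses; the helper name gumbel_top_n is hypothetical.

```python
import numpy as np


def gumbel_top_n(probs, n, rng=None):
    """Hypothetical helper: sample n distinct indices without replacement,
    proportional to probs, using the Gumbel-top-k trick."""
    rng = np.random.default_rng() if rng is None else rng
    p = np.asarray(probs, dtype=np.float64)
    if n > np.count_nonzero(p):
        raise ValueError("n exceeds the number of non-zero probabilities")
    with np.errstate(divide="ignore"):
        # log(0) -> -inf, so zero-probability entries are never selected
        keys = np.log(p) + rng.gumbel(size=p.shape)
    top = np.argpartition(-keys, n - 1)[:n]  # the n largest keys, unordered
    return top[np.argsort(-keys[top])].tolist()  # descending key order matches draw order
```

Because adding a constant to every key leaves the ranking unchanged, probs does not need to be normalized here. Whether this beats np.random.choice at nlpaug's vocabulary sizes has not been measured.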