TorchScript trace slower with C++ runtime environment.
sukuya opened this issue · 3 comments
sukuya commented
I traced the BERT model from the PyTorch-Transformers library and am getting the following results for 10 iterations.
a) Using the Python runtime to run the forward pass: 979,292 µs
import time
import torch

model = torch.jit.load('models_backup/2_2.pt')
x = torch.randint(2000, (1, 14), dtype=torch.long, device='cpu')

start = time.time()
for i in range(10):
    model(x)
end = time.time()
print((end - start) * 1000000, "µs")
b) Using the C++ runtime to run the forward pass: 3,333,758 µs, which is almost 3x the Python time.
std::vector<torch::jit::IValue> input;
torch::Tensor x = torch::randint(index_max, {1, inputsize},
                                 torch::dtype(torch::kInt64).device(torch::kCPU));
input.push_back(x);

// Run the model once before timing and turn its output into a tuple.
auto outputs = module->forward(input).toTuple();

auto start = chrono::steady_clock::now();
for (int16_t i = 0; i < 10; ++i)
{
    outputs = module->forward(input).toTuple();
}
auto end = chrono::steady_clock::now();

cout << "Elapsed time in microseconds : "
     << chrono::duration_cast<chrono::microseconds>(end - start).count()
     << " µs" << endl;
@thomwolf any suggestions on what I am missing?
Meteorix commented
Two possible reasons:
- The first time you run `forward` it does some warm-up work, so maybe you should exclude the first run.
- Try excluding `toTuple`.

In my experience, JIT with Python or C++ costs almost the same time.
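For reference, a minimal sketch of that change, assuming `module` and `input` are set up as in the C++ snippet above (older API where `torch::jit::load` returns a `std::shared_ptr`); the first `forward` call is kept out of the timed loop and the `toTuple` conversion is skipped inside it:

#include <torch/script.h>

#include <chrono>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical helper: times 10 forward passes after one warm-up run.
void benchmark(std::shared_ptr<torch::jit::script::Module> module,
               std::vector<torch::jit::IValue>& input) {
    // Warm-up: the first forward pass does one-time work, so it is
    // excluded from the timed region.
    auto warmup = module->forward(input);

    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < 10; ++i) {
        // Keep the raw IValue; skip the toTuple() conversion inside the loop.
        auto out = module->forward(input);
    }
    auto end = std::chrono::steady_clock::now();

    std::cout << "Elapsed time in microseconds : "
              << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count()
              << " µs" << std::endl;
}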
sukuya commented
@Meteorix `forward` is called once before the loop; are you talking about something else?
Excluding `toTuple` doesn't help.
stale commented
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.