Can the TCN module be made faster?
bjourne opened this issue · 8 comments
I'm using your TCN module for a language modeling task. My code follows the structure of your char_cnn code. It works but the performance is very bad compared to an LSTM network. Each epoch with the TCN network takes about 10 times longer. Do you know if the performance can be improved? Here is the forward method from the TCN class:
def forward(self, x):
    # x: (batch, seq_len) token indices
    emb = self.drop(self.encoder(x))     # (batch, seq_len, emb_size)
    y = self.tcn(emb.transpose(1, 2))    # Conv1d expects (batch, channels, seq_len)
    o = self.decoder(y.transpose(1, 2))  # back to (batch, seq_len, num_classes)
    return o.contiguous()
Perhaps it is the transpose calls that are making the code slow?
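(For what it's worth, a quick check suggests transpose itself only creates a view and copies nothing; any actual copy would happen later, e.g. in contiguous() or inside the convolution. The shapes below are just the ones from my model:)

import torch

x = torch.randn(32, 320, 100)        # (batch, seq_len, emb_size)
y = x.transpose(1, 2)                # (batch, emb_size, seq_len)
print(y.data_ptr() == x.data_ptr())  # True: same storage, no copy made
print(y.is_contiguous())             # False: a copy may happen downstream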
Thanks for your interest in our work!
The longer the sequence, the better TCN's efficiency compared to an LSTM. This is because a TCN still has depth (e.g., 8 layers), but each layer processes its T tokens in parallel, so the fixed per-layer cost is amortized over more tokens. So one way is to train with a longer sequence length.
The other way to speed up the TCN is to play with the kernel size and the dilation factor: once you have a larger receptive field per layer, you can reduce the number of layers (see the sketch below). The transpose operation is indeed not optimal, but I definitely don't think it is responsible for the efficiency problem here.
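For reference, a rough receptive-field estimate can guide that trade-off. This assumes the structure used in tcn.py (two dilated convolutions per residual block, dilation doubling at every level) and is only a sanity-check sketch:

def receptive_field(kernel_size, num_levels):
    # Each level adds two convs with dilation 2**i; each conv extends the
    # visible context by (kernel_size - 1) * 2**i positions.
    return 1 + 2 * (kernel_size - 1) * sum(2 ** i for i in range(num_levels))

print(receptive_field(kernel_size=3, num_levels=4))  # 61
print(receptive_field(kernel_size=7, num_levels=3))  # 85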
You can also do profiling to see which part of the model takes the most time: https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.profile
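A minimal sketch of wrapping the training loop in that profiler could look like this (model, train_loader, criterion and optimizer stand in for your own objects):

import torch

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for x, y in train_loader:          # your data iterator
        optimizer.zero_grad()
        loss = criterion(model(x), y)  # your model and loss
        loss.backward()
        optimizer.step()

# Aggregate per-operator statistics, sorted by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))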
Thanks for your response! I'm benchmarking a TCN model that is identical to your model in char_cnn_test.py versus a two-layer RNN model that I created. The models have roughly the same number of parameters:
-------------------------------------------------------------------------
Layer (type) Output Shape Param # Tr. Param #
=========================================================================
Embedding-1 [32, 320, 100] 4,900 4,900
Dropout-2 [32, 320, 100] 0 0
TemporalConvNet-3 [32, 100, 320] 2,217,050 2,217,050
Linear-4 [32, 320, 49] 4,949 4,949
=========================================================================
Total params: 2,226,899
Trainable params: 2,226,899
Non-trainable params: 0
-------------------------------------------------------------------------
and
--------------------------------------------------------------------------------------------------
Layer (type) Output Shape Param # Tr. Param #
==================================================================================================
Embedding-1 [32, 320, 100] 4,900 4,900
Dropout-2 [32, 320, 100] 0 0
LSTM-3 [32, 320, 700], [1, 32, 700], [1, 32, 700] 2,245,600 2,245,600
Linear-4 [10240, 49] 34,349 34,349
==================================================================================================
Total params: 2,284,849
Trainable params: 2,284,849
Non-trainable params: 0
--------------------------------------------------------------------------------------------------
The length of the sequences is 320 and I'm benchmarking on Google Colab with the GPU option. I'm using the Penn Treebank dataset that you used. Logs for the TCN:
200 / 514 | 44s | 2.7770
400 / 514 | 44s | 2.1416
-> 1 / 200 - 114s - 4.000 - 1.8228 *
200 / 514 | 44s | 1.7575
400 / 514 | 44s | 1.5874
-> 2 / 200 - 114s - 4.000 - 1.4341 *
200 / 514 | 44s | 1.4313
400 / 514 | 44s | 1.3583
-> 3 / 200 - 114s - 4.000 - 1.2771 *
200 / 514 | 44s | 1.2918
400 / 514 | 44s | 1.2519
-> 4 / 200 - 114s - 4.000 - 1.2037 *
So as you can see, each epoch takes 114 seconds. Logs for the RNN:
200 / 514 | 12s | 2.2584
400 / 514 | 12s | 1.8055
-> 1 / 200 - 33s - 4.000 - 1.5874 *
200 / 514 | 12s | 1.5346
400 / 514 | 12s | 1.4220
-> 2 / 200 - 33s - 4.000 - 1.3404 *
200 / 514 | 12s | 1.3173
400 / 514 | 12s | 1.2643
-> 3 / 200 - 33s - 4.000 - 1.2335 *
200 / 514 | 12s | 1.2146
400 / 514 | 12s | 1.1801
-> 4 / 200 - 33s - 4.000 - 1.1729 *
Only 33 seconds per epoch. Here is my benchmark program: https://github.com/bjourne/python3-libs/blob/master/tools/char_lm.py
The profiles for the training loops of the two networks follow. For the TCN:
---------------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls
---------------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
embedding_dense_backward 93.23% 105.381s 93.23% 105.381s 205.022ms 514
_local_scalar_dense 4.05% 4.583s 4.05% 4.583s 4.458ms 1028
cudnn_convolution_backward 0.36% 408.795ms 0.36% 408.795ms 99.415us 4112
cudnn_convolution 0.29% 324.076ms 0.29% 324.076ms 78.812us 4112
add_ 0.20% 225.748ms 0.20% 225.748ms 8.972us 25162
norm 0.18% 201.492ms 0.18% 201.492ms 15.680us 12850
zero_ 0.15% 174.809ms 0.15% 174.809ms 8.107us 21564
...
UnsafeViewBackward 0.00% 480.047us 0.00% 2.187ms 4.254us 514
is_floating_point 0.00% 394.121us 0.00% 394.121us 0.767us 514
---------------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
Self CPU time total: 113.090s
and for the RNN:
----------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls
----------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
embedding_dense_backward 63.21% 20.797s 63.21% 20.797s 40.461ms 514
_local_scalar_dense 21.49% 7.069s 21.49% 7.069s 6.876ms 1028
_cudnn_rnn_backward 6.97% 2.293s 6.97% 2.293s 4.462ms 514
_cudnn_rnn 6.63% 2.183s 6.63% 2.183s 4.247ms 514
norm 0.22% 74.007ms 0.22% 74.007ms 17.998us 4112
...
CloneBackward 0.00% 262.108us 0.00% 262.108us 0.510us 514
----------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
Self CPU time total: 32.901s
As you can see, the TCN issues many more (though individually cheap) addition and normalization operations, whereas the RNN makes far fewer such calls. The most straightforward way to speed up the TCN is therefore to reduce its depth or increase the dilation factor (see the sketch below). However, when you compare a TCN with a multi-layer LSTM, I don't think the TCN would be that slow.
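For example (a sketch only; I'm assuming the TemporalConvNet(num_inputs, num_channels, kernel_size, dropout) interface from tcn.py), a shallower stack with a larger kernel covers a comparable context with fewer layers:

from tcn import TemporalConvNet

# 6 levels with kernel size 3: receptive field of roughly 253 positions.
deep = TemporalConvNet(num_inputs=100, num_channels=[450] * 6,
                       kernel_size=3, dropout=0.1)

# 4 levels with kernel size 9: roughly 241 positions, but 4 fewer conv layers.
shallow = TemporalConvNet(num_inputs=100, num_channels=[450] * 4,
                          kernel_size=9, dropout=0.1)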
I have also run benchmarks comparing the Keras TCN (https://github.com/philipperemy/keras-tcn) with LSTMs, and in those the TCN wins. So I don't think the problem is inherent to the TCN structure; rather, something must be wrong with your implementation.
Possibly, yeah, but if you compare keras-tcn with the tcn.py in this repo, they have essentially the same structure, built from stacks of padded Conv1d layers (cf. https://github.com/philipperemy/keras-tcn/blob/master/tcn/tcn.py#L111 and https://github.com/philipperemy/keras-tcn/blob/master/tcn/tcn.py#L145).
The TCN implementation in this repo is a very simple one (as you have probably already noticed), and I can hardly see where it "must be wrong". There is also likely a framework-level difference in how these modules are implemented (e.g., different frameworks usually implement RNN modules in different ways). In case you are interested, this page, for example, finds that PyTorch's RNNs (e.g., GRUs) are a lot faster than Keras's (see the orange bars), whereas CNNs are about as fast.
Thanks for the insightful & careful benchmarking! I think it's very useful to investigate the most efficient way of implementing and deploying TCNs in general :-)
Have you looked at https://arxiv.org/abs/1611.09482?
I think the performance increase there mainly applies when you're predicting an autoregressive process, though, which is quite inefficient with the default conv structure. I've looked at the PyTorch LSTM source and it seems to be almost entirely implemented in C++, so you cannot really compare performance directly.
In my mind, a TCN is an RNN (at least when you're modeling an autoregressive process), but it's trained using teacher forcing, in the same way that Matlab trains a NARX model in "open loop" but then predicts in "closed loop". You can write a NARX(p) model as an RNN whose state is the p previous inputs and whose state update shifts the state vector by one and appends the new "previous output".
Similarly, you can express a TCN either as a stack of convolutions or as a recursive filter on a moving window (sketched below).
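As a rough illustration of that moving-window view (model here is just any causal model that maps a (batch, window) block of token ids to per-position logits; all names are placeholders):

import torch

@torch.no_grad()
def generate(model, start, window, steps):
    # The state is the last `window` tokens, analogous to a NARX(p) state vector.
    state = list(start)[-window:]
    out = []
    for _ in range(steps):
        x = torch.tensor(state).unsqueeze(0)  # (1, window)
        logits = model(x)[:, -1]              # logits for the next position
        nxt = int(logits.argmax(dim=-1))
        out.append(nxt)
        state = state[1:] + [nxt]             # shift the window, append the new output
    return out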
I went with a transformer model instead. For my problem it reaches just as low a loss as the TCN, but it trains much faster.