locuslab/TCN

Can the TCN module be made faster?

bjourne opened this issue · 8 comments

I'm using your TCN module for a language modeling task. My code follows the structure of your char_cnn code. It works, but the performance is much worse than that of an LSTM network: each epoch with the TCN takes about 10 times longer. Do you know if the performance can be improved? Here is the forward method from the TCN class:

    def forward(self, x):
        # x: (batch, seq_len) token ids
        emb = self.drop(self.encoder(x))      # (batch, seq_len, emb_size)
        y = self.tcn(emb.transpose(1, 2))     # TCN expects (batch, channels, seq_len)
        o = self.decoder(y.transpose(1, 2))   # back to (batch, seq_len, vocab)
        return o.contiguous()

Perhaps it is the transpose calls that are making the code slow?

Thanks for your interest in our work!

The longer the sequence, the better TCN's efficiency compared to an LSTM. This is because a TCN still has depth (e.g., 8 layers), but each layer processes its T tokens in parallel, whereas an LSTM has to step through the sequence one token at a time. So one way is to train with a longer sequence length.

The other way to speed up TCN is to play with the kernel size and the dilation factor. Obviously, once you have a larger receptive field per layer, you can reduce the number of layers. The transpose operation is indeed not optimal, but I definitely don't think it is responsible for the efficiency problem here.
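
For a rough sense of that trade-off, here is a small sketch (my own illustration, not code from the repo) assuming the layout in tcn.py: two dilated convolutions per residual block, with the dilation doubling at each level:

    # Rough illustration: receptive field of a TCN with two dilated convs per
    # residual block and the dilation doubling at each level.
    def receptive_field(kernel_size, num_levels):
        return 1 + 2 * (kernel_size - 1) * (2 ** num_levels - 1)

    # How many levels are needed to cover, say, a 320-token history?
    for k in (2, 3, 5, 7):
        levels = 1
        while receptive_field(k, levels) < 320:
            levels += 1
        print(f"kernel_size={k}: {levels} levels")  # 8, 7, 6, 5 levels respectively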

You can also profile to see which parts of the model take the most time: https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.profile
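
A minimal profiling sketch along those lines (the model, criterion, and (x, y) batch here are placeholders for your own training step, not code from this repo):

    import torch

    # Profile a single training step; use_cuda also records GPU kernel times.
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        output = model(x)
        loss = criterion(output.view(-1, output.size(-1)), y.view(-1))
        loss.backward()

    # Sort by self CPU time to see which ops dominate.
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15))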

Thanks for your response! I'm benchmarking a TCN model that is identical to your model in char_cnn_test.py versus a two-layer RNN model that I created. The models have roughly the same number of parameters:

-------------------------------------------------------------------------
        Layer (type)        Output Shape         Param #     Tr. Param #
=========================================================================
         Embedding-1      [32, 320, 100]           4,900           4,900
           Dropout-2      [32, 320, 100]               0               0
   TemporalConvNet-3      [32, 100, 320]       2,217,050       2,217,050
            Linear-4       [32, 320, 49]           4,949           4,949
=========================================================================
Total params: 2,226,899
Trainable params: 2,226,899
Non-trainable params: 0
-------------------------------------------------------------------------

and

--------------------------------------------------------------------------------------------------
      Layer (type)                                   Output Shape         Param #     Tr. Param #
==================================================================================================
       Embedding-1                                 [32, 320, 100]           4,900           4,900
         Dropout-2                                 [32, 320, 100]               0               0
            LSTM-3     [32, 320, 700], [1, 32, 700], [1, 32, 700]       2,245,600       2,245,600
          Linear-4                                    [10240, 49]          34,349          34,349
==================================================================================================
Total params: 2,284,849
Trainable params: 2,284,849
Non-trainable params: 0
--------------------------------------------------------------------------------------------------

The sequence length is 320 and I'm benchmarking on Google Colab with the GPU option. I'm using the Penn Treebank dataset that you used. Logs for the TCN:

  200 /   514 |  44s | 2.7770
  400 /   514 |  44s | 2.1416
\->  1 / 200 - 114s - 4.000 - 1.8228 *
  200 /   514 |  44s | 1.7575
  400 /   514 |  44s | 1.5874
\->  2 / 200 - 114s - 4.000 - 1.4341 *
  200 /   514 |  44s | 1.4313
  400 /   514 |  44s | 1.3583
\->  3 / 200 - 114s - 4.000 - 1.2771 *
  200 /   514 |  44s | 1.2918
  400 /   514 |  44s | 1.2519
\->  4 / 200 - 114s - 4.000 - 1.2037 *

So as you can see each epoch takes 114 seconds. Logs for the RNN:

  200 /   514 |  12s | 2.2584
  400 /   514 |  12s | 1.8055
\->  1 / 200 -  33s - 4.000 - 1.5874 *
  200 /   514 |  12s | 1.5346
  400 /   514 |  12s | 1.4220
\->  2 / 200 -  33s - 4.000 - 1.3404 *
  200 /   514 |  12s | 1.3173
  400 /   514 |  12s | 1.2643
\->  3 / 200 -  33s - 4.000 - 1.2335 *
  200 /   514 |  12s | 1.2146
  400 /   514 |  12s | 1.1801
\->  4 / 200 -  33s - 4.000 - 1.1729 *

Only 33 seconds per epoch. Here is my benchmark program: https://github.com/bjourne/python3-libs/blob/master/tools/char_lm.py

The profiles of the training loops for the two networks follow. For the TCN:

----------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                                      Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
----------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
embedding_dense_backward                  93.23%           105.381s         93.23%           105.381s         205.022ms        514
_local_scalar_dense                       4.05%            4.583s           4.05%            4.583s           4.458ms          1028
cudnn_convolution_backward                0.36%            408.795ms        0.36%            408.795ms        99.415us         4112
cudnn_convolution                         0.29%            324.076ms        0.29%            324.076ms        78.812us         4112
add_                                      0.20%            225.748ms        0.20%            225.748ms        8.972us          25162
norm                                      0.18%            201.492ms        0.18%            201.492ms        15.680us         12850
zero_                                     0.15%            174.809ms        0.15%            174.809ms        8.107us          21564
...
UnsafeViewBackward                        0.00%            480.047us        0.00%            2.187ms          4.254us          514              
is_floating_point                         0.00%            394.121us        0.00%            394.121us        0.767us          514              
----------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 113.090s

and for the RNN:

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                 Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
embedding_dense_backward             63.21%           20.797s          63.21%           20.797s          40.461ms         514              
_local_scalar_dense                  21.49%           7.069s           21.49%           7.069s           6.876ms          1028             
_cudnn_rnn_backward                  6.97%            2.293s           6.97%            2.293s           4.462ms          514              
_cudnn_rnn                           6.63%            2.183s           6.63%            2.183s           4.247ms          514              
norm                                 0.22%            74.007ms         0.22%            74.007ms         17.998us         4112
...
CloneBackward                        0.00%            262.108us        0.00%            262.108us        0.510us          514              
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 32.901s

As you can see, the TCN performs a lot more (though individually cheap) addition and normalization operations, whereas the RNN makes far fewer such calls. The most straightforward way to speed up the TCN is therefore to reduce its depth or increase the dilation factor. However, in general, when you compare a TCN with a multi-layer LSTM, I don't think the TCN would be that slow.
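
As a rough sketch of that (assuming TemporalConvNet in tcn.py takes (num_inputs, num_channels, kernel_size, dropout) and the import path used in the examples), a wider kernel lets you drop levels while keeping roughly the same receptive field:

    from TCN.tcn import TemporalConvNet  # import path as in the repo's examples

    emb_size, hidden = 100, 450  # illustrative sizes, not your exact config

    # 6 levels, kernel 3: receptive field 1 + 2*2*(2**6 - 1) = 253
    deep_narrow = TemporalConvNet(emb_size, [hidden] * 6, kernel_size=3, dropout=0.1)

    # 4 levels, kernel 9: receptive field 1 + 2*8*(2**4 - 1) = 241, with two fewer levels
    shallow_wide = TemporalConvNet(emb_size, [hidden] * 4, kernel_size=9, dropout=0.1)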

I also have benchmarks comparing the Keras TCN (https://github.com/philipperemy/keras-tcn) with LSTMs, and in those the TCN wins. So I don't think the problem is inherent to the TCN structure; rather, something must be wrong with your implementation.

Possibly, yeah, but if you compare keras-tcn with the tcn.py in this repo, they have essentially the same structure, which relies on stacks of padded Conv1d layers (cf. https://github.com/philipperemy/keras-tcn/blob/master/tcn/tcn.py#L111 and https://github.com/philipperemy/keras-tcn/blob/master/tcn/tcn.py#L145).

The TCN implementation in this repo is a very simple one (as you have probably already noticed), and I can hardly see where it "must be wrong". There are also likely to be framework-level differences in how these modules are implemented (e.g., different frameworks usually implement RNN modules in different ways). In case you are interested, for example, this page finds that PyTorch's RNNs (e.g., GRUs) are a lot faster than Keras's (see the orange bars), whereas CNNs are about equally fast.

Thanks for the insightful & careful benchmarking! I think it's very useful to investigate the most efficient way of implementing and deploying TCNs in general :-)

Have you looked at https://arxiv.org/abs/1611.09482

Although I think the performance increase there only shows up when you're predicting from an autoregressive process, which is quite inefficient with the default conv structure. I've looked at the PyTorch LSTM source, and it seems to be almost entirely implemented in C++, so you cannot really compare the implementations directly.

In my mind, a TCN is an RNN (at least when you're modeling an autoregressive process), but one trained with teacher forcing, in the same way that MATLAB trains a NARX model in "open loop" but then predicts in "closed loop". You can write a NARX(p) model as an RNN whose state is the p previous inputs, with the state update being to shift the previous state vector by one and append the new "previous output".

Similarly, you can express a TCN either as a stack of convolutions or as a recursive filter over a moving window, for example:
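
A minimal sketch of that moving-window view (my own illustration; it assumes a model that maps a (1, T) batch of token ids to (1, T, vocab) logits, as in the char_cnn setup discussed above):

    import torch

    @torch.no_grad()
    def generate(model, prefix, steps, receptive_field):
        # The "state" is just the last `receptive_field` tokens, as in a NARX(p) model.
        window = prefix[:, -receptive_field:].clone()   # (1, T) token ids
        out = []
        for _ in range(steps):
            logits = model(window)                      # (1, T, vocab)
            next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
            out.append(next_tok)
            # Shift the window by one and append the new "previous output".
            window = torch.cat([window, next_tok], dim=1)[:, -receptive_field:]
        return torch.cat(out, dim=1)

Each step here still re-runs the convolutions over the whole window; the caching approach in the paper linked above is about avoiding exactly that redundant computation.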

I went with a transformer model instead. For my problem it reaches just as low a loss as the TCN, but it trains much faster.