philipperemy/keras-tcn

Parallel training of TCN

mmawwxx opened this issue · 7 comments

Hi TCN fans,

Recently I've been using TCN for some forecasting work, and it works pretty well. I have one question about TCN training. As mentioned in the introduction document, one advantage of the TCN model is that it can be trained in parallel: within each layer the same kernels are applied at every timestep, so there is no sequential waiting as in LSTMs and other RNNs. My question is: is this parallel training achieved automatically by the existing code, or does it have to be set up manually with additional code?

Thanks, much appreciated!

By the way, I compared the training time of TCN and CNN-LSTM, set up to have a similar number of parameters. Both take about 5-10 min to finish training, and in many cases CNN-LSTM is even faster... So I wondered whether this is because I didn't set up the parallel computation correctly.
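
To illustrate the point about parallelism: in a causal convolution each output timestep reads only the layer's inputs, never another output of the same layer, so the timesteps can be computed in any order (or all at once), whereas a recurrent step has to wait for the previous one. The plain-Python sketch below uses hypothetical helper names and toy weights; it is not keras-tcn code. In Keras, convolution layers already run as single vectorized operations over the time axis, so this kind of parallelism should not require any extra setup.

```python
import random


def causal_conv_at(x, w, dilation, t):
    """Output of a causal dilated 1-D convolution at time t (zero left-padding).

    Reads only x[t], x[t - dilation], x[t - 2*dilation], ...
    """
    return sum(
        w[j] * x[t - j * dilation]
        for j in range(len(w))
        if t - j * dilation >= 0
    )


def conv_in_order(x, w, dilation, order):
    """Fill the output in the given timestep order; any order is valid."""
    y = [None] * len(x)
    for t in order:
        y[t] = causal_conv_at(x, w, dilation, t)
    return y


def toy_recurrence(x, a, b):
    """h[t] = a * h[t-1] + b * x[t]: step t cannot start before step t-1 is done."""
    h, out = 0.0, []
    for x_t in x:
        h = a * h + b * x_t
        out.append(h)
    return out


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w = [0.5, 0.25]  # toy kernel, assumed values

in_order = conv_in_order(x, w, 2, range(len(x)))

shuffled = list(range(len(x)))
random.Random(0).shuffle(shuffled)
out_of_order = conv_in_order(x, w, 2, shuffled)

# Same result regardless of the order in which timesteps were computed.
print(in_order == out_of_order)
```

Whether this translates into shorter wall-clock training time depends on the hardware and the sequence length, so a CNN-LSTM of similar size being as fast or faster is not necessarily a sign of a misconfiguration.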

@mmawwxx good point. Some parts were discussed in this issue: #69

Ideally, some computations could be re-used when forecasting step by step. I think that's what you wanted to say. That is not the case at the moment: every time you call predict(), all the computations are done from scratch, so it's not really efficient for real-time forecasting. But that's how Keras works, I guess. Optimizing it would be insanely difficult in my opinion. PyTorch would probably be a better fit for this type of work.
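
As a rough picture of what "re-using computations" could mean, here is a minimal plain-Python sketch of a single causal dilated convolution that keeps a small buffer of past inputs and produces one new output per incoming value, instead of recomputing the whole sequence. The class and function names are hypothetical, the weights are toy values, and nothing here is part of keras-tcn or Keras; a full TCN would need one such buffer per layer.

```python
from collections import deque


def causal_conv_full(x, w, dilation):
    """Recompute every output from scratch (what a fresh predict() call amounts to)."""
    return [
        sum(
            w[j] * x[t - j * dilation]
            for j in range(len(w))
            if t - j * dilation >= 0
        )
        for t in range(len(x))
    ]


class StreamingCausalConv:
    """Single-layer causal dilated convolution that caches past inputs."""

    def __init__(self, w, dilation):
        self.w = w
        self.dilation = dilation
        span = (len(w) - 1) * dilation + 1  # inputs reachable from one output
        # Zeros stand in for the left padding before the first real input.
        self.buffer = deque([0.0] * span, maxlen=span)

    def step(self, x_t):
        """Consume one new input and return the matching output in O(len(w))."""
        self.buffer.append(x_t)
        # buffer[-1] is x[t]; buffer[-1 - j*dilation] is x[t - j*dilation].
        return sum(
            self.w[j] * self.buffer[-1 - j * self.dilation]
            for j in range(len(self.w))
        )


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w = [0.5, 0.25]  # toy kernel, assumed values

stream = StreamingCausalConv(w, dilation=2)
streamed = [stream.step(v) for v in x]
full = causal_conv_full(x, w, dilation=2)
```

Both paths give the same outputs; the streaming one only touches `len(w)` buffered values per new timestep.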

@philipperemy Thanks for the response! After checking issue #69, I think my case is pretty similar to it. I also set return_sequences to False and keep only a single slice of the last layer, so the computation for the remaining slices of the last layer is wasted. I have two follow-up questions and would appreciate your suggestions:
(1) I'm currently doing multi-step-ahead forecasting (e.g. 24 steps ahead). I set return_sequences to False, hoping that all the information in the input sequence can be extracted and compressed into a single slice of the last layer (just like the encoder in #69). I then use an MLP to map this single slice to the 24 forecasts, i.e. Dense(24). So I'm not calling the same model 24 times, but producing the 24-step-ahead forecast in one shot. Methodologically, do you think this is a reasonable way of using TCN for multi-step forecasting? Or should I set return_sequences to True?
(2) If I set return_sequences to False and keep only a single slice of the last layer, is there a way to skip the computation of the other slices in the last layer? This should save a lot of computation time.
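
To get a feel for how much work question (2) is about, the sketch below counts which timesteps of each layer actually feed the final output slice of a stack of causal dilated convolutions. It is a simplified model under stated assumptions: one convolution per dilation, no residual branches, and a hypothetical function name. It is not how keras-tcn computes anything, and a real residual block may stack more than one convolution per dilation, which would widen these sets.

```python
def positions_needed_for_last_output(seq_len, kernel_size, dilations):
    """Time indices per layer that the final timestep depends on.

    Returns a list of sets: index 0 is the input sequence, index i is the
    output of the i-th convolution (with dilation dilations[i - 1]).
    """
    needed = {seq_len - 1}          # only the last slice of the top layer is kept
    per_layer = [needed]
    for d in reversed(dilations):   # walk from the top layer down to the input
        needed = {
            t - j * d
            for t in needed
            for j in range(kernel_size)
            if t - j * d >= 0
        }
        per_layer.append(needed)
    per_layer.reverse()
    return per_layer


seq_len, kernel_size, dilations = 64, 2, [1, 2, 4, 8]  # toy configuration
layers = positions_needed_for_last_output(seq_len, kernel_size, dilations)

counts = [len(s) for s in layers]
print(counts)  # [16, 8, 4, 2, 1]

needed_outputs = sum(counts[1:])            # layer outputs that matter
computed_outputs = len(dilations) * seq_len  # layer outputs computed in full
print(needed_outputs, computed_outputs)      # 15 256
```

In this toy configuration only 15 of the 256 per-layer outputs contribute to the kept slice, which is the waste the question refers to.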

That's what I was expecting. My points:

  • In theory you could just learn P(X_t+1 | X_0,...,X_t) and call the model 24 times, feeding each prediction back in as an input. The problem is that the errors accumulate, and 24 steps is not trivial.
  • Your methodology looks good to me. You have some sort of seq-to-seq architecture. And I don't see anything wrong here.
  • Like you said, a Dense(24) has you modeling P(X_t+1,..., X_t+24 | X_0,...,X_t). So you are generating from the joint distribution, and not from the conditional distributions (my point 1). One disadvantage of that is that you assume that X_t+1 and X_t+2 are independent. Said differently, if you permuted X_t+1 and X_t+2 in your vector, the problem would be the same.
  • You can also have a look at RepeatVector(24) and TimeDistributed(Dense).
  • Using return_sequences=False is the right way to me, because the length of your inputs (something other than 24) differs from the length of your outputs (24). With return_sequences=True, X_t+1 would also be conditioned only on X_0, X_t+2 only on (X_0, X_1), and so on, which is not really what you want. You want to input the full sequence first and then unroll over 24 values.
  • Regarding (2): unfortunately, that's not really possible at the moment. So yes, you are wasting a lot of resources for sure, especially on very long sequences, but I don't have anything to offer you on that right now ;)
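
The difference between the first point (calling a one-step model 24 times) and the Dense(24) approach can be sketched with a deliberately tiny example. Everything here is an assumption made for illustration: the series is constant, the "models" are single multiplicative coefficients that are each 2% off, and the function names are hypothetical. Real direct forecasts usually get less accurate at longer horizons too, so this only shows the compounding mechanism, not a general ranking of the two approaches.

```python
HORIZON = 24


def recursive_forecast(last_value, one_step_coef, horizon=HORIZON):
    """Call a one-step model repeatedly, feeding each prediction back in."""
    preds, x = [], last_value
    for _ in range(horizon):
        x = one_step_coef * x   # the model's own error is fed back as input
        preds.append(x)
    return preds


def direct_forecast(last_value, per_horizon_coefs):
    """One call, one output per horizon (the role Dense(24) plays in the thread)."""
    return [c * last_value for c in per_horizon_coefs]


truth = [10.0] * HORIZON                        # toy series that stays at 10
rec = recursive_forecast(10.0, 1.02)            # one-step model, 2% too high
direct = direct_forecast(10.0, [1.02] * HORIZON)  # each output, 2% too high

rec_err = [abs(p - t) for p, t in zip(rec, truth)]
direct_err = [abs(p - t) for p, t in zip(direct, truth)]

print(round(rec_err[0], 3), round(rec_err[-1], 3))
print(round(max(direct_err), 3))
```

In the recursive case the 2% error is applied again at every step, so the error at step 24 is far larger than at step 1, while in the direct case it stays at roughly 2% for every horizon.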

I am 99% sure about what I'm discussing here, but I might be wrong, especially on the stats part. So do not hesitate to question it.

@mmawwxx my pleasure!