This repo implements common methods for time series prediction, especially deep learning methods in TensorFlow 2. Contributions are very welcome: if you have a better idea, just create a PR. If you have any questions, feel free to open an issue.
Implemented models:

- ARIMA
- GBDT
- RNN
- wavenet
- transformer
- U-Net
- n-beats
- GAN
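To illustrate the autoregressive (AR) idea behind the AR component of ARIMA, here is a minimal pure-Python sketch. It fits an AR(1) model `y[t] = c + phi * y[t-1]` by ordinary least squares and then rolls it forward. This is not this repo's ARIMA implementation: it has no differencing (I) or moving-average (MA) terms, and `fit_ar1` and `forecast_ar1` are hypothetical names used only for illustration.

```python
from typing import List, Tuple


def fit_ar1(series: List[float]) -> Tuple[float, float]:
    """Fit y[t] = c + phi * y[t-1] by ordinary least squares (closed form).

    Hypothetical helper for illustration; not part of this repo.
    """
    if len(series) < 3:
        raise ValueError("need at least 3 observations")
    x = series[:-1]  # lagged values y[t-1]
    y = series[1:]   # targets y[t]
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    # Slope = covariance(x, y) / variance(x); intercept follows from the means.
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var = sum((a - mean_x) ** 2 for a in x)
    if var == 0:
        raise ValueError("lagged series is constant; phi is undefined")
    phi = cov / var
    c = mean_y - phi * mean_x
    return c, phi


def forecast_ar1(last_value: float, c: float, phi: float, horizon: int) -> List[float]:
    """Recursive multi-step forecast: each prediction is fed back as the next input."""
    preds = []
    value = last_value
    for _ in range(horizon):
        value = c + phi * value
        preds.append(value)
    return preds


if __name__ == "__main__":
    # A noise-free series generated by y[t] = 2 + 0.5 * y[t-1], starting at 0.
    s = [0.0, 2.0, 3.0, 3.5, 3.75, 3.875]
    c, phi = fit_ar1(s)
    print(c, phi)                          # recovers c = 2.0, phi = 0.5
    print(forecast_ar1(s[-1], c, phi, 2))  # [3.9375, 3.96875]
```

The recursive forecast is the simplest multi-step strategy; the deep learning models in the list can also predict a whole horizon at once.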
- Install the library

  ```bash
  pip install -r requirements.txt
  ```
- Download the data, if necessary

  ```bash
  bash ./data/download_passenger.sh
  ```
- Train the model; set `custom_model_params` if you want to customize the model

  ```bash
  cd examples
  python run_train.py --use_model seq2seq
  ```
- Predict new data

  ```bash
  python run_test.py
  ```
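Deep learning forecasters such as RNN or seq2seq models are commonly trained on fixed-length windows cut from the raw series: each sample pairs `input_len` past values with the next `horizon` values to predict. The function below is a minimal, framework-free sketch of that framing. `make_windows`, `input_len`, `horizon` and `stride` are hypothetical names, not this repo's API, and the scripts in `examples` may prepare the data differently.

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float], input_len: int, horizon: int, stride: int = 1
) -> Tuple[List[List[float]], List[List[float]]]:
    """Split a 1-D series into (history, future) pairs for supervised training.

    Hypothetical helper for illustration; not part of this repo.
    """
    if input_len <= 0 or horizon <= 0 or stride <= 0:
        raise ValueError("input_len, horizon and stride must be positive")
    inputs, targets = [], []
    # Last start index that still leaves room for a full input and target window.
    last_start = len(series) - input_len - horizon
    for start in range(0, last_start + 1, stride):
        end = start + input_len
        inputs.append(list(series[start:end]))         # model input (history)
        targets.append(list(series[end:end + horizon]))  # values to predict
    return inputs, targets


if __name__ == "__main__":
    x, y = make_windows([0, 1, 2, 3, 4, 5], input_len=3, horizon=2)
    print(x)  # [[0, 1, 2], [1, 2, 3]]
    print(y)  # [[3, 4], [4, 5]]
```

A series shorter than `input_len + horizon` yields no samples, so short datasets may need a smaller window.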