
PyTorch implementation of the paper "StockFormer: Learning Hybrid Trading Machines with Predictive Coding".


StockFormer (IJCAI'23)

Code repository for this paper:
StockFormer: Learning Hybrid Trading Machines with Predictive Coding.
Siyu Gao, Yunbo Wang, Xiaokang Yang

Preparation

Installation

git clone https://github.com/gsyyysg/StockFormer.git
cd StockFormer
pip install -r requirements.txt

Dataset

Downloaded from Yahoo Finance.

Experiment

Data

dir: 'data/CSI/'

Code

dir: 'code/'

1st stage: Representation Learning

1) Relational state inference module training:

cd code/Transformer/script
sh train_mae.sh

2) Long-term state inference module training:

cd code/Transformer/script
sh train_pred_long.sh

3) Short-term state inference module training:

cd code/Transformer/script
sh train_pred_short.sh

4) For each of the three state inference modules, select the best checkpoint from 'code/Transformer/checkpoints/' according to its performance on the validation set, and copy it to 'code/Transformer/pretrained/'.
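Step 4 is done by hand. The sketch below shows one way to automate it. It assumes you have collected one validation loss per checkpoint from the training logs, where lower is better, and that each checkpoint is stored as its own subdirectory of 'code/Transformer/checkpoints/'. Both function names are hypothetical helpers and are not part of this repository.

```python
import shutil
from pathlib import Path
from typing import Dict


def select_best_checkpoint(val_losses: Dict[str, float]) -> str:
    """Return the checkpoint name with the lowest validation loss.

    Assumption: the metric is lower-is-better (e.g. a validation MSE);
    use max() instead if yours is higher-is-better.
    """
    if not val_losses:
        raise ValueError("no checkpoints to choose from")
    return min(val_losses, key=lambda name: val_losses[name])


def copy_to_pretrained(checkpoints_dir: Path, name: str, pretrained_dir: Path) -> Path:
    """Copy checkpoints_dir/name to pretrained_dir/name and return the destination.

    Assumption: each checkpoint is stored as its own subdirectory.
    """
    src = checkpoints_dir / name
    dst = pretrained_dir / name
    pretrained_dir.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok=True (Python 3.8+) lets a re-run overwrite an earlier copy.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return dst
```

Here is an example with hypothetical run names. `select_best_checkpoint({"mae_run1": 0.41, "mae_run2": 0.39})` returns "mae_run2". You would then call `copy_to_pretrained(Path("code/Transformer/checkpoints"), "mae_run2", Path("code/Transformer/pretrained"))` to copy it. Repeat this once per state inference module.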

Alternatively, directly use the models we have already pretrained (dir: 'code/Transformer/pretrained/csi/').

2nd stage: Policy Learning

1) Train the SAC (Soft Actor-Critic) model. The paths of the three state inference modules can be changed in train_rl.py.

python train_rl.py

2) Get the prediction results on the test set from 'code/results/df_print/'.
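This README does not document the format of the files in 'code/results/df_print/'. The sketch below assumes a result file is a CSV with one row per trading day and a column holding the portfolio value. The column name "account_value" is an assumption. From that series, the helpers compute a cumulative return and an annualized Sharpe ratio. The helpers are hypothetical and assume 252 trading days per year and a zero risk-free rate. They may not match how the metrics reported in the paper were computed.

```python
import csv
import math
import statistics
from typing import List, Sequence


def load_account_values(path: str, column: str = "account_value") -> List[float]:
    """Read one float per row from the given CSV column (column name is an assumption)."""
    with open(path, newline="") as f:
        return [float(row[column]) for row in csv.DictReader(f)]


def daily_returns(values: Sequence[float]) -> List[float]:
    """Simple period-over-period returns: v[t] / v[t-1] - 1."""
    return [cur / prev - 1.0 for prev, cur in zip(values[:-1], values[1:])]


def cumulative_return(values: Sequence[float]) -> float:
    """Total return over the whole test period."""
    return values[-1] / values[0] - 1.0


def annualized_sharpe(values: Sequence[float], periods_per_year: int = 252) -> float:
    """Mean over sample standard deviation of daily returns, scaled by sqrt(periods_per_year).

    Assumes a zero risk-free rate.
    """
    returns = daily_returns(values)
    if len(returns) < 2:
        raise ValueError("need at least three values to estimate a Sharpe ratio")
    std = statistics.stdev(returns)
    if std == 0:
        raise ValueError("returns have zero variance")
    return statistics.mean(returns) / std * math.sqrt(periods_per_year)
```

For example, you could call `annualized_sharpe(load_account_values(path))`, where `path` points at one of the result files and you have adjusted `column` to match its header.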

Citation

If you find our work helpful, please cite our paper.

@inproceedings{gaostockformer,
  title={StockFormer: Learning Hybrid Trading Machines with Predictive Coding},
  author={Gao, Siyu and Wang, Yunbo and Yang, Xiaokang},
  booktitle={IJCAI},
  year={2023}
}

Acknowledgements

This codebase is based on FinRL.