First, run the script ./scripts/data/mnist_preprocess.py
to generate the pre-processed data:
python ./scripts/data/mnist_preprocess.py --data_size 28
cache: root directory for saved data; 'test_data.txt', 'train_data.txt', 'test_label.txt', and 'train_label.txt' are saved in this directory.
scripts: all scripts
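Before training, it can help to confirm that pre-processing wrote all four files. The following is a minimal sketch, not part of the repository: the helper name missing_cache_files is hypothetical, and it assumes the cache directory is ./cache relative to where it is run. Only the four file names come from the description above.

```python
from pathlib import Path

# File names as listed above; the cache location is an assumption.
EXPECTED_FILES = ("train_data.txt", "train_label.txt", "test_data.txt", "test_label.txt")


def missing_cache_files(cache_dir="cache"):
    """Return the expected pre-processed files that are absent from cache_dir."""
    root = Path(cache_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_cache_files()
    if missing:
        print("Missing files, rerun mnist_preprocess.py:", ", ".join(missing))
    else:
        print("All pre-processed files found.")
```

If the cache lives elsewhere, pass its path as the argument, for example missing_cache_files("path/to/cache").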
To change the model type, set the model_type parameter: model_type=transformer or model_type=linear
To train and test, run the following command:
python main.py --lr 0.001 --batch_size 256 --n_epochs 10 --data_split_dim 4 --data_dimension 8 --n_heads 1 --model_type transformer --gama_scale 0.001
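For reference, here is a hedged sketch of how the flags in the command above could be declared with argparse. It is not taken from main.py: the function name build_parser, the argument types, and the defaults (copied from the example command) are all assumptions. Only the flag names and the two model_type values come from this README.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the command above."""
    parser = argparse.ArgumentParser(description="Train and test on pre-processed MNIST")
    # Types and defaults are assumptions based on the example command.
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--data_split_dim", type=int, default=4)
    parser.add_argument("--data_dimension", type=int, default=8)
    parser.add_argument("--n_heads", type=int, default=1)
    parser.add_argument("--model_type", choices=["transformer", "linear"], default="transformer")
    parser.add_argument("--gama_scale", type=float, default=0.001)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

Restricting model_type with choices, as sketched here, would reject any value other than transformer or linear at startup rather than partway through a run.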