QData/spacetimeformer

Getting error on PEMS-BAY dataset

kadattack opened this issue · 3 comments

I downloaded the `pems-bay.h5` file from https://zenodo.org/record/4263971 and used https://github.com/liyaguang/DCRNN/blob/master/scripts/generate_training_data.py to generate `test.npz`, `train.npz`, and `val.npz` in `./data/pems_bay/`.
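A quick way to sanity-check the generated files is to print the name and shape of every array they contain. This is a minimal sketch: the helper name `describe_npz` is hypothetical, and the key names you will see (probably `x`, `y`, `x_offsets`, `y_offsets` from DCRNN's script) are an assumption, not something confirmed in this thread.

```python
import os

import numpy as np


def describe_npz(path):
    """Return {array name: shape} for every array stored in an .npz file."""
    # np.load on an .npz returns an NpzFile, which can be used as a context manager.
    with np.load(path) as data:
        return {name: data[name].shape for name in data.files}


if __name__ == "__main__":
    # Directory taken from the --data_path used in the issue.
    data_dir = "./data/pems_bay/"
    for split in ("train", "val", "test"):
        path = os.path.join(data_dir, f"{split}.npz")
        if os.path.exists(path):
            print(split, describe_npz(path))
```

The last axis of `x` is the per-node feature count, which is the dimension that turns out to matter here.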

When I run the command from the README.md:

`python train.py spacetimeformer pems-bay --batch_size 32 --warmup_steps 1000 --d_model 200 --d_ff 700 --enc_layers 5 --dec_layers 6 --dropout_emb .1 --dropout_ff .3 --run_name pems-bay-spatiotemporal --base_lr 1e-3 --l2_coeff 1e-3 --loss mae --data_path ./data/pems_bay/ --d_qk 30 --d_v 30 --n_heads 10 --patience 10 --decay_factor .8`

I get the following error:

Traceback (most recent call last):
  File "/spacetimeformer-main/spacetimeformer/train.py", line 854, in <module>
    main(args)
  File "/spacetimeformer-main/spacetimeformer/train.py", line 758, in main
    ) = create_dset(args)
  File "/spacetimeformer-main/spacetimeformer/train.py", line 394, in create_dset
    data = stf.data.metr_la.METR_LA_Data(config.data_path)
  File "/spacetimeformer-main/spacetimeformer/data/metr_la/metr_la.py", line 43, in __init__
    x_c_train, y_c_train = self._split_set(context_train)
  File "/spacetimeformer-main/spacetimeformer/data/metr_la/metr_la.py", line 21, in _split_set
    time = 2.0 * x[:, :, 0] - 1.0
IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed
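The failing expression indexes the array with three indices, so it needs at least a third (feature) axis. A minimal, hypothetical reproduction: the array shapes below are made up, and only the indexing expression is taken from `metr_la.py` line 21.

```python
import numpy as np

# Hypothetical shapes, chosen only to illustrate the failure mode.
x_2d = np.zeros((4, 5))     # 2-D array, as reported in the traceback
x_3d = np.zeros((4, 5, 3))  # 3-D array with a trailing feature axis


def time_feature(x):
    # Same expression as metr_la.py line 21: take feature 0 and map [0, 1] to [-1, 1].
    return 2.0 * x[:, :, 0] - 1.0


print(time_feature(x_3d).shape)  # (4, 5)
try:
    time_feature(x_2d)
except IndexError as err:
    # NumPy refuses three indices on a 2-D array.
    print("IndexError:", err)
```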

I have the same problem. How can I fix it?


Did you manage to solve it?

Set `add_day_in_week` to `True` in https://github.com/liyaguang/DCRNN/blob/master/scripts/generate_training_data.py#L72 and regenerate the `.npz` files.
The spacetimeformer code expects those features to be present.
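For illustration, here is a numpy-only sketch of the kind of calendar features these flags append to each sensor reading. It assumes, rather than confirms, that DCRNN's script adds a time-of-day fraction and a 7-way one-hot day of week. The function `add_time_features` is hypothetical and is not DCRNN's or spacetimeformer's code.

```python
import numpy as np


def add_time_features(values, timestamps, add_time_in_day=True, add_day_in_week=True):
    """Stack per-node readings with calendar features (hypothetical sketch).

    values: (num_samples, num_nodes) array of sensor readings.
    timestamps: (num_samples,) array of numpy datetime64 values.
    Returns an array of shape (num_samples, num_nodes, num_features).
    """
    num_samples, num_nodes = values.shape
    timestamps = np.asarray(timestamps, dtype="datetime64[s]")
    features = [values[..., np.newaxis]]
    if add_time_in_day:
        # Fraction of the day elapsed, in [0, 1), repeated for every node.
        seconds = (timestamps - timestamps.astype("datetime64[D]")) / np.timedelta64(1, "s")
        time_in_day = np.repeat((seconds / 86400.0)[:, np.newaxis], num_nodes, axis=1)
        features.append(time_in_day[..., np.newaxis])
    if add_day_in_week:
        # One-hot day of week (Monday=0 ... Sunday=6); 1970-01-01 was a Thursday (=3).
        days = timestamps.astype("datetime64[D]").astype(np.int64)
        dow = (days + 3) % 7
        day_in_week = np.zeros((num_samples, num_nodes, 7))
        day_in_week[np.arange(num_samples), :, dow] = 1.0
        features.append(day_in_week)
    return np.concatenate(features, axis=-1)
```

Under these assumptions, enabling both flags gives each node 1 + 1 + 7 = 9 features, while disabling them leaves a single reading per node. The thread does not confirm whether spacetimeformer's loader depends on exactly this layout.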