If cat_idxs is non-empty, cat_dims must be defined as a list of same length
parsifal9 commented
Hi Kathrin,
I get the following error with the TabNet model on the california_housing and covertype datasets.
python train.py --config config/covertype.yml --model_name TabNet
.
Namespace(config='config/covertype.yml', model_name='TabNet', dataset='Covertype', objective='classification',
use_gpu=True, gpu_ids=[0, 1, 2, 3], data_parallel=True, optimize_hyperparameters=False, n_trials=100,
direction='minimize',
num_splits=5, shuffle=True, seed=221, scale=True, target_encode=True, one_hot_encode=False, batch_size=128,
val_batch_size=256, early_stopping_rounds=20, epochs=1000, logging_period=100, num_features=54, num_classes=7,
cat_idx=None, cat_dims=None)
.
.
Traceback (most recent call last):
File "/scratch1/dun280/TabSurvey/train.py", line 154, in <module>
main_once(arguments)
File "/scratch1/dun280/TabSurvey/train.py", line 138, in main_once
sc, time = cross_validation(model, X, y, args)
File "/scratch1/dun280/TabSurvey/train.py", line 41, in cross_validation
loss_history, val_loss_history = curr_model.fit(X_train, y_train, X_test, y_test) # X_val, y_val)
File "/scratch1/dun280/TabSurvey/models/tabnet.py", line 38, in fit
self.model.fit(X, y, eval_set=[(X_val, y_val)], eval_name=["eval"], eval_metric=self.metric,
File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py", line 223, in fit
self._set_network()
File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py", line 570, in _set_network
self.network = tab_network.TabNet(
File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/tab_network.py", line 567, in __init__
self.embedder = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/tab_network.py", line 809, in __init__
raise ValueError(msg)
ValueError: If cat_idxs is non-empty, cat_dims must be defined as a list of same length.
Unfortunately, this is once again happening with my own data as well.
Bye
R
kathrinse commented
Hey,
this is also caused by an update to the TabNet implementation. I have now fixed it in the code.
You can easily fix it yourself in the __init__ method of models/tabnet.py by replacing the line
self.params["cat_idxs"] = args.cat_idx
with
self.params["cat_idxs"] = args.cat_idx if args.cat_idx else []
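A minimal Python sketch of why this one-line change helps, under stated assumptions. check_embedding_args is a hypothetical stand-in for the library-side validation, not pytorch_tabnet's actual code. It assumes only a literal empty list counts as "no categorical features", so the None coming from the config (cat_idx=None in the log above) is treated as non-empty and trips the error. tabnet_cat_params is also hypothetical: it applies the fix from this thread and, as our own defensive addition, normalises cat_dims the same way.

```python
from argparse import Namespace

ERR = "If cat_idxs is non-empty, cat_dims must be defined as a list of same length."


def check_embedding_args(cat_idxs, cat_dims):
    """Hypothetical stand-in for the library's check (assumption).

    Only a literal [] counts as empty, so None is treated as non-empty.
    Raises if exactly one of the two arguments is empty. Length-mismatch
    checks are omitted to keep the sketch short.
    """
    if (cat_idxs == []) != (cat_dims == []):
        raise ValueError(ERR)


def tabnet_cat_params(args: Namespace) -> dict:
    """Build the categorical parameters, applying the fix from this thread.

    Normalising cat_dims too is our own assumption, not the repository's code.
    """
    return {
        # The fix: a missing (None) index list becomes an empty list.
        "cat_idxs": args.cat_idx if args.cat_idx else [],
        "cat_dims": args.cat_dims if args.cat_dims else [],
    }


if __name__ == "__main__":
    args = Namespace(cat_idx=None, cat_dims=None)  # as in the printed Namespace

    # Before the fix: cat_idxs stays None while cat_dims is assumed to be [].
    try:
        check_embedding_args(args.cat_idx, [])
    except ValueError as err:
        print("error:", err)

    # After the fix: both are empty lists, so the check passes.
    params = tabnet_cat_params(args)
    check_embedding_args(params["cat_idxs"], params["cat_dims"])
    print(params)  # {'cat_idxs': [], 'cat_dims': []}
```

Normalising at the point where the config is read keeps None from leaking into code that expects lists. Datasets that do have categorical columns pass their index and dimension lists through unchanged.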
parsifal9 commented
Yes, I can confirm that this is now working as expected.
Bye and thanks
R