kathrinse/TabSurvey

If cat_idxs is non-empty, cat_dims must be defined as a list of same length


Hi Kathrin,

I get the following error with the TabNet model on the california_housing and covertype datasets:

python train.py --config config/covertype.yml --model_name TabNet
.
Namespace(config='config/covertype.yml', model_name='TabNet', dataset='Covertype', objective='classification', 
use_gpu=True, gpu_ids=[0, 1, 2, 3], data_parallel=True, optimize_hyperparameters=False, n_trials=100, 
direction='minimize', 
num_splits=5, shuffle=True, seed=221, scale=True, target_encode=True, one_hot_encode=False, batch_size=128, 
val_batch_size=256, early_stopping_rounds=20, epochs=1000, logging_period=100, num_features=54, num_classes=7, 
cat_idx=None, cat_dims=None)
.
.
Traceback (most recent call last):
  File "/scratch1/dun280/TabSurvey/train.py", line 154, in <module>
    main_once(arguments)
  File "/scratch1/dun280/TabSurvey/train.py", line 138, in main_once
    sc, time = cross_validation(model, X, y, args)
  File "/scratch1/dun280/TabSurvey/train.py", line 41, in cross_validation
    loss_history, val_loss_history = curr_model.fit(X_train, y_train, X_test, y_test)  # X_val, y_val)
  File "/scratch1/dun280/TabSurvey/models/tabnet.py", line 38, in fit
    self.model.fit(X, y, eval_set=[(X_val, y_val)], eval_name=["eval"], eval_metric=self.metric,
  File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py", line 223, in fit
    self._set_network()
  File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py", line 570, in _set_network
    self.network = tab_network.TabNet(
  File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/tab_network.py", line 567, in __init__
    self.embedder = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
  File "/home/dun280/.local/lib/python3.9/site-packages/pytorch_tabnet/tab_network.py", line 809, in __init__
    raise ValueError(msg)
ValueError: If cat_idxs is non-empty, cat_dims must be defined as a list of same length.

Unfortunately, once again, this is also happening with my own data.

Bye
R

Hey,

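This one is also caused by an update of the TabNet implementation (the pytorch_tabnet package). I have now fixed it in the code.
You can also fix it yourself: in the __init__ method in models/tabnet.py, replace the line

self.params["cat_idxs"] = args.cat_idx

with

self.params["cat_idxs"] = args.cat_idx if args.cat_idx else []

This passes an empty list instead of None when no categorical indices are configured (cat_idx=None, as in the Namespace printed above).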
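
To make the None-versus-[] distinction concrete, here is a minimal Python sketch. It assumes, based only on the error message, that the library treats any cat_idxs value other than [] as non-empty, so a raw None trips the check. check_cat_args, build_tabnet_params and as_list are hypothetical helpers written for illustration; they are not TabSurvey or pytorch_tabnet code, and normalizing cat_dims as well goes beyond the one-line fix.

```python
import argparse
from typing import List, Optional, Sequence


def as_list(values: Optional[Sequence[int]]) -> List[int]:
    """Return [] for None or an empty sequence, otherwise a list copy."""
    return list(values) if values else []


def build_tabnet_params(args: argparse.Namespace) -> dict:
    """Build the categorical part of the TabNet params from the parsed config.

    cat_idxs follows the suggested fix; normalizing cat_dims the same way is
    an extra assumption, not part of the original one-line change.
    """
    return {
        "cat_idxs": as_list(args.cat_idx),
        "cat_dims": as_list(args.cat_dims),
    }


def check_cat_args(cat_idxs, cat_dims) -> None:
    """Hypothetical stand-in for the library's validation (not its real code).

    Assumes "empty" means "equal to []", so None is treated as non-empty.
    """
    if cat_idxs == [] and cat_dims == []:
        return  # no categorical features, nothing to embed
    if cat_idxs != [] and not (
        isinstance(cat_dims, list) and len(cat_dims) == len(cat_idxs)
    ):
        raise ValueError(
            "If cat_idxs is non-empty, cat_dims must be defined as a list of same length."
        )


if __name__ == "__main__":
    # Same categorical settings as the covertype run in the report.
    args = argparse.Namespace(cat_idx=None, cat_dims=None)

    try:
        check_cat_args(args.cat_idx, args.cat_dims)  # raw None values
    except ValueError as err:
        print("before fix:", err)

    params = build_tabnet_params(args)
    check_cat_args(params["cat_idxs"], params["cat_dims"])  # passes
    print("after fix:", params)
```

The conditional expression from the fix relies on Python truthiness, so both None and an empty list end up as [], while a real list of indices is passed through unchanged.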

Yes, I can confirm that this now works as expected.
Bye and thanks
R