The kernel appears to have died at the `CatBoost` fitting stage if an incorrect `gpu_ids` value is passed
andrewsonin opened this issue · 2 comments
If I have only two available GPUs (indices 0 and 1), I can still pass `gpu_ids='2'` to `TabularAutoML` without getting an error.
`nvidia-smi` output:
```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA Tesla V1...  Off  | 00000000:3B:00.0 Off |                    0 |
| N/A   44C    P0    36W / 250W |   1581MiB / 32510MiB |     13%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA Tesla V1...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   47C    P0    47W / 250W |   1249MiB / 32510MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
```
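Code:

```python
from lightautoml.automl.presets.tabular_presets import TabularAutoML
from lightautoml.tasks import Task

automl = TabularAutoML(
    task=Task('multiclass', ...),
    ...,
    general_params={'use_algos': [['linear_l2', 'cb', 'cb_tuned']]},
    cb_params={'default_params': {'thread_count': 48}},
    gpu_ids='2',  # no GPU with index 2 exists on this machine (only 0 and 1)
    verbose=3
)
oof_predictions = automl.fit_predict(
    train,
    roles=...
)
```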
The kernel appears to have died at the `fit_predict` stage, when it tries to train the CatBoost model.
Changing `gpu_ids='2'` to `gpu_ids='1'` solves the problem.
Please add the corresponding check to the `__init__` of `TabularAutoML`.
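A minimal sketch of such a check, assuming `gpu_ids` is a comma-separated string of 0-based device indices (as in `gpu_ids='2'` above) and that the number of visible GPUs is already known, for example from `torch.cuda.device_count()`. The helper `validate_gpu_ids` is hypothetical and not part of the LightAutoML API; the idea is simply to fail fast with a `ValueError` instead of letting CatBoost crash the kernel later.

```python
from typing import List


def validate_gpu_ids(gpu_ids: str, n_available: int) -> List[int]:
    """Parse a comma-separated GPU ID string and reject IDs that do not exist.

    Hypothetical helper: raises ValueError up front instead of letting
    the CatBoost step crash the kernel during fit_predict.
    """
    if n_available <= 0:
        raise ValueError(f"gpu_ids={gpu_ids!r} was passed, but no GPUs are available")
    ids = []
    for token in gpu_ids.split(","):
        token = token.strip()
        # Only non-negative integer indices are accepted.
        if not token.isdecimal():
            raise ValueError(f"Invalid GPU ID {token!r} in gpu_ids={gpu_ids!r}")
        idx = int(token)
        # Device indices are 0-based, so valid IDs are 0 .. n_available - 1.
        if idx >= n_available:
            raise ValueError(
                f"GPU ID {idx} requested, but only {n_available} GPU(s) are available "
                f"(valid IDs: 0..{n_available - 1})"
            )
        ids.append(idx)
    return ids


if __name__ == "__main__":
    print(validate_gpu_ids("1", n_available=2))  # [1]
    try:
        validate_gpu_ids("2", n_available=2)  # the case from this issue
    except ValueError as err:
        print(err)
```

On the two-GPU machine above, `'1'` passes and `'2'` is rejected with a readable message at construction time.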
Hi, @andrewsonin. Thank you for your interest in our library.
We are planning to add separate installation options for the NLP, CV and reporting blocks (and possibly others). However, some of these setups will not include PyTorch, which is what we currently use to obtain the available GPU IDs when you don't pass `gpu_ids`. This could be solved by parsing the output of the aforementioned `nvidia-smi`. I think we will add the exception you proposed along with the pull request for the separate installation.
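Without PyTorch, one option (sketched here as an assumption, not necessarily the approach the maintainers chose) is to ask `nvidia-smi` for the GPU indices in machine-readable form with `--query-gpu=index --format=csv,noheader`, rather than parsing the human-readable table from the issue. The helpers `parse_gpu_indices` and `available_gpu_indices` are hypothetical names.

```python
import shutil
import subprocess
from typing import List


def parse_gpu_indices(csv_output: str) -> List[int]:
    """Parse output of `nvidia-smi --query-gpu=index --format=csv,noheader`."""
    # One index per line, e.g. "0\n1\n" on a machine with two GPUs.
    return [int(line) for line in csv_output.splitlines() if line.strip()]


def available_gpu_indices() -> List[int]:
    """Return GPU indices reported by nvidia-smi, or [] if it is missing or fails."""
    if shutil.which("nvidia-smi") is None:
        return []
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        )
        return parse_gpu_indices(result.stdout)
    except (subprocess.CalledProcessError, OSError, ValueError):
        # Non-zero exit, launch failure, or unexpected output: treat as no GPUs.
        return []


if __name__ == "__main__":
    print(parse_gpu_indices("0\n1\n"))  # [0, 1] for the two-GPU machine in this issue
    print(available_gpu_indices())
```

One caveat with this design: `nvidia-smi` ignores `CUDA_VISIBLE_DEVICES`, so the indices it reports may not match the device numbering that CUDA libraries such as CatBoost see when that variable is set.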