myGboosting
Простая реализация градиентного бустинга. Работает с задачами регрессии, использует метрику MSE.
Использование
Утилита предполагает вызов из командной строки. В качестве обучающей и тестовой выборки используются стандартные файлы csv. Утилита имеет 2 режима работы: обучение (fit) и применение (predict).
Сборка проекта
перейти в каталог src, создать в нем каталог build, перейти в него и выполнить следующие команды:
cmake ..
make -j4
при этом на компьютере должны быть установлена библиотека Open MP.
Обучение
Формат запуска:
myGboosting fit <file path> [optional parameters]
<file path>
- путь к csv файлу, содержащему обучающую выборку
Параметры:
Параметр | Описание | Значение по умолчанию |
---|---|---|
column_names | путь к файлу, содержащему названия столбцов обучающей выборки | |
model-path | имя файла, в который будет сохранена обученная модель | |
output-path | имя файла, в который будут сохранены прогнозы обученной модели на обучающей выборке | |
target | номер колонки, которая содержит значение целевой переменной | последняя колонка(-1) |
nthread | число параллельных потоков, используемых для обучения | 1 |
delimiter | разделитель, используемый в csv-файлах | , |
has-header | имеет ли входной csv-файл заголовок с названиями столбцов | false |
iterations | максимальное количество деревьев в модели | 100 |
learning-rate | темп обучения модели | 1.0 |
depth | глубина решающего дерева | 6 |
max_bins | количество сплитов в гистограмме для числовых признаков (от 1 до 255) | 10 |
verbose | степень подробности выводимой в консоль информации | 0 |
sample_rate | вероятность сэмплинга строк для каждого дерева (какую часть датасета использовать) | 0.66 |
min_leaf_count | минимальное количество объектов в листовой вершине | 1 |
Применение
Формат запуска:
myGboosting predict <file path> [optional parameters]
<file path> - путь к csv файлу, содержащему тестовую выборку
Параметры:
Параметр | Описание | Значение по умолчанию |
---|---|---|
column_names | путь к файлу, содержащему названия столбцов тестовой выборки | |
model-path | имя файла, из которого будет считана обученная модель | |
output-path | имя файла, в который будут сохранены прогнозы модели на тестовой выборке | |
delimiter | разделитель, используемый в csv-файлах | , |
has-header | имеет ли входной csv-файл заголовок с названиями столбцов | false |
has-target | имеет ли входной csv-файл колонку со значениями целевой переменной | false |
target | номер колонки, которая содержит значение целевой переменной | последняя колонка(-1) |
verbose | степень подробности выводимой в консоль информации | 0 |
Архитектура
- Используются Oblivious Decision trees и гистограммы признаков
- Параллелизация при обучении происходит при выборе оптимального сплита c помощью библиотеки Open MP
- для сохранения и загрузки моделей используется библиотека Protobuf
Результаты
Модель проверялась на наборе данных Higgs, была взята train выборка (250000 сэмплов) и поделена на train (20000) и тест (50000). Поскольку последняя колонка - это бинарная классификация, то предсказывалась предпоследняя (вес частицы).
Параметры запуска:
LightGBM
time ./lightgbm objective=mse data=../../myGboosting/testing/datasets/Higgs/train.csv
num_threads=1 num_iterations=400 max_bin=255 bagging_fraction=0.5 feature_fraction=1.0 bagging_freq=1
num_leaves=64 learning_rate=0.5 label=31 min_data_in_leaf=1
myGboosting
time ./myGboosting fit ../testing/datasets/Higgs/train.csv --output=model.pb
--iterations=400 --depth=6 --learning-rate=0.5 --sample-rate=0.5 --max_bins=255
--nthreads=1 --verbose=1
Результаты производительности на 1 потоке
Модель проверялась на ноутбуке Macbook Pro 15 2015. Поэтому в наличии есть 4 реальных ядра и 8 виртуальных.
Решение | Depth | Row sampling | Кол-во деревьев | Learning Rate | Время | MSE Train | MSE Test |
---|---|---|---|---|---|---|---|
LightGBM | 3 | 0.5 | 400 | 0.5 | 8.185s | 1.20136 | 1.39434 |
myGboosting | 3 | 0.5 | 400 | 0.5 | 7.420s | 1.26377 | 1.3244 |
LightGBM | 6 | 0.5 | 400 | 0.5 | 13.352s | 0.6469 | 1.88662 |
myGboosting | 6 | 0.5 | 400 | 0.5 | 12.062s | 1.11156 | 1.44719 |
LightGBM | 9 | 0.5 | 400 | 0.5 | 62.043s | 0.02141 | 2.7639 |
myGboosting | 9 | 0.5 | 400 | 0.5 | 30.145s | 0.742787 | 1.88279 |
LightGBM | 4 | 0.5 | 800 | 0.1 | 16.285s | 1.11651 | 1.30245 |
myGboosting | 4 | 0.5 | 800 | 0.1 | 16.161s | 1.24526 | 1.29291 |
LightGBM | 4 | 0.5 | 4000 | 0.02 | 78.118s | 1.1006 | 1.28318 |
myGboosting | 4 | 0.5 | 4000 | 0.04 | 76.491s | 1.1894 | 1.28369 |
LightGBM | 4 | 0.5 | 200 | 0.7 | 5.214s | 1.2102 | 1.55602 |
myGboosting | 4 | 0.5 | 200 | 0.7 | 4.784s | 1.27188 | 1.35914 |
LightGBM | 6 | 0.7 | 400 | 0.5 | 13.643s | 0.5028 | 1.7779 |
myGboosting | 6 | 0.7 | 400 | 0.5 | 13.974s | 1.06907 | 1.4014 |
LightGBM | 6 | 1.0 | 400 | 0.5 | 13.483s | 0.4661 | 1.61041 |
myGboosting | 6 | 1.0 | 400 | 0.5 | 16.315s | 1.05107 | 1.35413 |
LightGBM | 9 | 1.0 | 400 | 0.2 | 55.146s | 0.06689 | 1.42377 |
myGboosting | 9 | 1.0 | 400 | 0.2 | 35.697s | 0.85707 | 1.3268 |
Из таблицы видно, что наше решение обгоняет LightGBM в подавляющем большинстве случаев. Странным образом на LightGBM влияет параметр Row sampling. При его уменьшении скорость должна расти, а у него она падает.
Также видно, что темп обучения у ODT отстает от темпа обучения обычных деревьев. При этом качество на тестовой выборке на данных экспериментах у нашей модели лучше, но это вызвано тем, что модель LightGBM успевает переобучиться за заданное число итераций.
Результаты производительности в многопоточном режиме (время)
Решение | 1 поток | 2 потока | 4 потока | 6 потоков | 8 потоков |
---|---|---|---|---|---|
LightGBM | 13.352s | 10.404s | 8.727s | 9.070s | 9.766s |
myGboosting | 12.062s | 8.915s | 7.604s | 7.304s | 7.313s |
Параметры запуска такие же, как приведены в начале прошлого пункта (глубина 6, row sampling 0.5, 400 деревьев)
Видно, что после 4 потоков производительность почти не растет. Наше решение обгоняет LightGBM с ростом числа потоков.
Результаты измерений качества на отложенной выборке (test)
Попробовав достаточное количество различных комбинаций параметров, мы увидели, что MSE на отложенной выборке (test) не опускалось ниже 1.283, и LightGBM, и myGboosting приближались к этой границе при росте числа деревьев до 4000 и маленьком learning rate (см. колонку test MSE в таблице выше). Из этого мы можем сделать вывод, что по качеству наше решение не уступает LightGBM, но learning rate для нашего решения нужно выставлять в 2 раза больше, чем для LightGBM.
Ход наших экспериментов
- Мы реализовали базовую версию решающего дерева, которая перебирала все возможные сплиты.
- Стало понятно, что это работает очень медленно, и мы перешли к гистограммам
- Далее, возникла идея перенести гистограммы на уровень фичей и добавить работу с категориальными признаками. Для этого мы реализовали бинаризацию численных признаков и one hot кодирование для категориальных признаков
- В ходе замеров производительности выяснилось, что данный подход работает медленно и требует большого количества памяти на больших датасетах (Higgs).
- После этого мы перешли к реализации Oblivious Decision Trees (ODT). Мы выбрали вариант ODT, при котором мы выбираем одно значение сплита для всех узлов конкретного уровня, это позволяет хранить дерево как список сплитов (id_признака, № корзины) и список значений в листовых вершинах (2^<Глубина дерева> вершин)
- Эксперименты с ODT показали, что MSE уменьшается гораздо медленнее, чем на обычных деревьях. После отладки выяснилось, что наше разбиение на корзины работает не оптимально, и мы взяли его реализацию из LightGBM.
- Кроме того, на MSE негативно влиял параметр min_child_weight (минимальное кол-во объектов в листах). Выяснилось, что для ODT оптимальнее строить полные бинарные деревья, независимо от числа объектов в листьях (min_child_weight=1).
- Теперь MSE стало вести себя гораздо лучше.
- Также одной из особенностей ODT является то, что фичи в одном и том же дереве не повторяются. Мы пробовали отключать это правило и на небольших датасетах это давало лучшее уменьшение MSE, но на больших датасетах MSE уменьшается одинаково, а с включенным правилом алгоритм работает быстрее.
- Далее мы перешли к реализации нескольких потоков. После профилирования с помощью Valgrind выяснилось, что основное время работы алгоритма занимает построение гистограмм. Поэтому для приемлемой производительности многопоточного алгоритма оказалось достаточно распараллелить поиск оптимального сплита, в рамках которого у нас строятся гистограммы.
- Для реализации параллельных потоков была выбрана библиотека OpenMP, т.к. LightGBM и XGBoost используют ее же. Параллельность была реализована с помощью параллельного цикла for с критической секцией по выбору максимального Gain.
- При сравнении производительности с LightGBM была замечена странная особенность: изменение параметра row sampling (сэмплирование датасета для конкретного дерева) не приводит к увеличению производительности LightGBM. При этом наше решение увеличивает свою производительность (что логично, т.к. дереву нужно учиться на меньшем объеме данных).
- При сборке на других машинах возникли проблемы со сборкой библиотеки Protobuf, поэтому пришлось написать кастомную сериализацию моделей.