/myGboosting

Gradient boosting implementation

Primary LanguageC++

myGboosting

Простая реализация градиентного бустинга. Работает с задачами регрессии, использует метрику MSE.

Использование

Утилита предполагает вызов из командной строки. В качестве обучающей и тестовой выборки используются стандартные файлы csv. Утилита имеет 2 режима работы: обучение (fit) и применение (predict).

Сборка проекта

перейти в каталог src, создать в нем каталог build, перейти в него и выполнить следующие команды:

cmake ..
make -j4

при этом на компьютере должны быть установлена библиотека Open MP.

Обучение

Формат запуска:

myGboosting fit <file path> [optional parameters]

<file path> - путь к csv файлу, содержащему обучающую выборку

Параметры:

Параметр Описание Значение по умолчанию
column_names путь к файлу, содержащему названия столбцов обучающей выборки
model-path имя файла, в который будет сохранена обученная модель
output-path имя файла, в который будут сохранены прогнозы обученной модели на обучающей выборке
target номер колонки, которая содержит значение целевой переменной последняя колонка(-1)
nthread число параллельных потоков, используемых для обучения 1
delimiter разделитель, используемый в csv-файлах ,
has-header имеет ли входной csv-файл заголовок с названиями столбцов false
iterations максимальное количество деревьев в модели 100
learning-rate темп обучения модели 1.0
depth глубина решающего дерева 6
max_bins количество сплитов в гистограмме для числовых признаков (от 1 до 255) 10
verbose степень подробности выводимой в консоль информации 0
sample_rate вероятность сэмплинга строк для каждого дерева (какую часть датасета использовать) 0.66
min_leaf_count минимальное количество объектов в листовой вершине 1

Применение

Формат запуска:

myGboosting predict <file path> [optional parameters]

<file path> - путь к csv файлу, содержащему тестовую выборку

Параметры:

Параметр Описание Значение по умолчанию
column_names путь к файлу, содержащему названия столбцов тестовой выборки
model-path имя файла, из которого будет считана обученная модель
output-path имя файла, в который будут сохранены прогнозы модели на тестовой выборке
delimiter разделитель, используемый в csv-файлах ,
has-header имеет ли входной csv-файл заголовок с названиями столбцов false
has-target имеет ли входной csv-файл колонку со значениями целевой переменной false
target номер колонки, которая содержит значение целевой переменной последняя колонка(-1)
verbose степень подробности выводимой в консоль информации 0

Архитектура

  • Используются Oblivious Decision trees и гистограммы признаков
  • Параллелизация при обучении происходит при выборе оптимального сплита c помощью библиотеки Open MP
  • для сохранения и загрузки моделей используется библиотека Protobuf

Результаты

Модель проверялась на наборе данных Higgs, была взята train выборка (250000 сэмплов) и поделена на train (20000) и тест (50000). Поскольку последняя колонка - это бинарная классификация, то предсказывалась предпоследняя (вес частицы).

Параметры запуска:

LightGBM

time ./lightgbm objective=mse data=../../myGboosting/testing/datasets/Higgs/train.csv 
num_threads=1 num_iterations=400 max_bin=255 bagging_fraction=0.5 feature_fraction=1.0 bagging_freq=1 
num_leaves=64 learning_rate=0.5 label=31 min_data_in_leaf=1

myGboosting

time ./myGboosting fit  ../testing/datasets/Higgs/train.csv --output=model.pb 
--iterations=400 --depth=6 --learning-rate=0.5 --sample-rate=0.5 --max_bins=255 
--nthreads=1 --verbose=1

Результаты производительности на 1 потоке

Модель проверялась на ноутбуке Macbook Pro 15 2015. Поэтому в наличии есть 4 реальных ядра и 8 виртуальных.

Решение Depth Row sampling Кол-во деревьев Learning Rate Время MSE Train MSE Test
LightGBM 3 0.5 400 0.5 8.185s 1.20136 1.39434
myGboosting 3 0.5 400 0.5 7.420s 1.26377 1.3244
LightGBM 6 0.5 400 0.5 13.352s 0.6469 1.88662
myGboosting 6 0.5 400 0.5 12.062s 1.11156 1.44719
LightGBM 9 0.5 400 0.5 62.043s 0.02141 2.7639
myGboosting 9 0.5 400 0.5 30.145s 0.742787 1.88279
LightGBM 4 0.5 800 0.1 16.285s 1.11651 1.30245
myGboosting 4 0.5 800 0.1 16.161s 1.24526 1.29291
LightGBM 4 0.5 4000 0.02 78.118s 1.1006 1.28318
myGboosting 4 0.5 4000 0.04 76.491s 1.1894 1.28369
LightGBM 4 0.5 200 0.7 5.214s 1.2102 1.55602
myGboosting 4 0.5 200 0.7 4.784s 1.27188 1.35914
LightGBM 6 0.7 400 0.5 13.643s 0.5028 1.7779
myGboosting 6 0.7 400 0.5 13.974s 1.06907 1.4014
LightGBM 6 1.0 400 0.5 13.483s 0.4661 1.61041
myGboosting 6 1.0 400 0.5 16.315s 1.05107 1.35413
LightGBM 9 1.0 400 0.2 55.146s 0.06689 1.42377
myGboosting 9 1.0 400 0.2 35.697s 0.85707 1.3268

Из таблицы видно, что наше решение обгоняет LightGBM в подавляющем большинстве случаев. Странным образом на LightGBM влияет параметр Row sampling. При его уменьшении скорость должна расти, а у него она падает.

Также видно, что темп обучения у ODT отстает от темпа обучения обычных деревьев. При этом качество на тестовой выборке на данных экспериментах у нашей модели лучше, но это вызвано тем, что модель LightGBM успевает переобучиться за заданное число итераций.

Результаты производительности в многопоточном режиме (время)

Решение 1 поток 2 потока 4 потока 6 потоков 8 потоков
LightGBM 13.352s 10.404s 8.727s 9.070s 9.766s
myGboosting 12.062s 8.915s 7.604s 7.304s 7.313s

Параметры запуска такие же, как приведены в начале прошлого пункта (глубина 6, row sampling 0.5, 400 деревьев)

Видно, что после 4 потоков производительность почти не растет. Наше решение обгоняет LightGBM с ростом числа потоков.

Результаты измерений качества на отложенной выборке (test)

Попробовав достаточное количество различных комбинаций параметров, мы увидели, что MSE на отложенной выборке (test) не опускалось ниже 1.283, и LightGBM, и myGboosting приближались к этой границе при росте числа деревьев до 4000 и маленьком learning rate (см. колонку test MSE в таблице выше). Из этого мы можем сделать вывод, что по качеству наше решение не уступает LightGBM, но learning rate для нашего решения нужно выставлять в 2 раза больше, чем для LightGBM.

Ход наших экспериментов

  1. Мы реализовали базовую версию решающего дерева, которая перебирала все возможные сплиты.
  2. Стало понятно, что это работает очень медленно, и мы перешли к гистограммам
  3. Далее, возникла идея перенести гистограммы на уровень фичей и добавить работу с категориальными признаками. Для этого мы реализовали бинаризацию численных признаков и one hot кодирование для категориальных признаков
  4. В ходе замеров производительности выяснилось, что данный подход работает медленно и требует большого количества памяти на больших датасетах (Higgs).
  5. После этого мы перешли к реализации Oblivious Decision Trees (ODT). Мы выбрали вариант ODT, при котором мы выбираем одно значение сплита для всех узлов конкретного уровня, это позволяет хранить дерево как список сплитов (id_признака, № корзины) и список значений в листовых вершинах (2^<Глубина дерева> вершин)
  6. Эксперименты с ODT показали, что MSE уменьшается гораздо медленнее, чем на обычных деревьях. После отладки выяснилось, что наше разбиение на корзины работает не оптимально, и мы взяли его реализацию из LightGBM.
  7. Кроме того, на MSE негативно влиял параметр min_child_weight (минимальное кол-во объектов в листах). Выяснилось, что для ODT оптимальнее строить полные бинарные деревья, независимо от числа объектов в листьях (min_child_weight=1).
  8. Теперь MSE стало вести себя гораздо лучше.
  9. Также одной из особенностей ODT является то, что фичи в одном и том же дереве не повторяются. Мы пробовали отключать это правило и на небольших датасетах это давало лучшее уменьшение MSE, но на больших датасетах MSE уменьшается одинаково, а с включенным правилом алгоритм работает быстрее.
  10. Далее мы перешли к реализации нескольких потоков. После профилирования с помощью Valgrind выяснилось, что основное время работы алгоритма занимает построение гистограмм. Поэтому для приемлемой производительности многопоточного алгоритма оказалось достаточно распараллелить поиск оптимального сплита, в рамках которого у нас строятся гистограммы.
  11. Для реализации параллельных потоков была выбрана библиотека OpenMP, т.к. LightGBM и XGBoost используют ее же. Параллельность была реализована с помощью параллельного цикла for с критической секцией по выбору максимального Gain.
  12. При сравнении производительности с LightGBM была замечена странная особенность: изменение параметра row sampling (сэмплирование датасета для конкретного дерева) не приводит к увеличению производительности LightGBM. При этом наше решение увеличивает свою производительность (что логично, т.к. дереву нужно учиться на меньшем объеме данных).
  13. При сборке на других машинах возникли проблемы со сборкой библиотеки Protobuf, поэтому пришлось написать кастомную сериализацию моделей.

Используемые библиотеки

  1. https://github.com/ben-strasser/fast-cpp-csv-parser
  2. https://github.com/Taywee/args
  3. https://developers.google.com/protocol-buffers/
  4. https://www.openmp.org