Устройство утилиты
При чтении датасета имена фичей заменяются на их хэш (аналогично тому, как это делает vw), хэшфункция принимает 2^b
значений, где b
задаётся при запуске.
Есть два режима обучения: градиентный спуск (SGD или AdaGrad) и ALS.
При обучении в несколько потоков применяется Hogwild (как в градиентном спуске, так и в ALS).
Возможно чтение датасетов в форматах csv и libsvm.
Градиентный спуск
Возможна оптимизация трёх loss-функций: mse, mae и logistic
SGD
Размер шага вычисляется по следующей формуле (взято у vw):
λ d^k / t,
где λ
— learning rate, d
— learning rate decay, k
— число полных эпох обучения, t
— число итераций (просмотренных записей) обучения.
AdaGrad
Размер шага для параметра i
вычисляется по формуле
λ d^k / t_i,
где t_i = sqrt(sum g_{i,j}^2)
, g_{i,j}
— градиент i
-го параметра на j
-й итерации обучения.
ALS
Алгоритм реализован по статье от libFM.
Можно попытаться сделать ALS для произвольной loss-функции, если посчитать вторую производную loss-функции по предсказанию. Тогда формула (22) в статье будет выглядеть так:
Эта штука взрывается, если вторая производная маленькая (mae или logistic при предсказаниях близких к правильному), но если её ограничить снизу константой вроде 0.1~1.0, то вроде работает. На нормальные эксперименты времени не хватило.
Установка
Утилита написана на Rust.
Сборка утилиты:
cargo build --release
Запуск
Команда запуска
cargo run -q --release -- <utility options>
Опции:
-h,--help Справка по опциям утилиты
-p,--predict Запуск в режиме предсказания. По умолчанию утилита запускается в режиме обучения
-d,--data DATA Путь к датасету для обучения или вычисления предсказаний
--data_type DATA_TYPE Тип датасета. Возможные значения: csv, libsvm. Значение по умолчанию: libsvm
-t,--target TARGET Поле таргета для csv датасетов
-m,--model MODEL Путь к модели, куда записывается обученная модель или откуда берётся модель для предсказания
-o,--output OUTPUT Путь к файлу для записи вычисленных предсказаний. При отсутствии предсказания не выводятся.
--opt OPT Тип оптимизатора. Возможные значения: sgd, adagrad, als. Значение по умолчанию: adagrad.
--loss LOSS Лосс функция. Возможные значения: mse, logistic, mae. Значение по умолчанию: mse
-i,--iterations ITERATIONS Число эпох обучения. Значение по умолчанию: 10
-b,--bits BITS Число бит хэш-функции. Значение по умолчанию: 18
-k,--factors_number Число факторов в модели: Значение по умолчанию: 10
--l2 L2 Значение l2-регуляризатора. Значение по умолчанию: 1e-5
--lr LR Значение learning rate для градиентного спуска
--decay DECAY Значение learning rate decay для градиентного спуска
-j,--jobs JOBS Число потоков
Примеры:
Обучение с AdaGrad:
cargo run -q --release -- -d datasets/train_20m_wo_time.csv --data_type csv --target rating -m model --loss mse -i 20 -j 8
Обучение с ALS и вычислением скора на тесте после каждой итерации:
cargo run -q --release -- -d datasets/train_20m_wo_time.csv --test_data datasets/test_20m_wo_time.csv --data_type csv --target rating -m model --loss mse -i 20 -j 8 --opt als
Вычисление скора на тесте:
cargo run -q --release -- -p -d datasets/test_20m_wo_time.csv --data_type csv --target rating -m model --loss mae
Бенчмарки
Movielens
Для vowpal wabbit использовался этот бенчмарк с заменой датасета на 20m. Результаты:
linear test MAE is 0.652
lrq test MAE is 0.639
lrqdropout test MAE is 0.608
lrqdropouthogwild test MAE is 0.787
Наша утилита:
AdaGrad, 8 факторов:
train mse: 0.592
test mse: 0.677
train mae: 0.586
test mae: 0.627
Минимум на тесте около 10-й итерации.
Минимум ошибки на тесте достигается при ~8 факторах.
ALS, 8 факторов, l2 1e-6 (~15), 20 итераций
train mse: 0.593
test mse: 0.672
train mae: 0.588
test mae: 0.625
ALS, 12 факторов, l2 1e-6 (~15), 20 итераций
train mse: 0.566
test mse: 0.669
train mae: 0.573
test mae: 0.622
libFM:
SGD, 8 факторов, lr 0.01, 17 итераций
train mse: 0.756
test mse: 0.811
SGD, 8 факторов, lr 0.01, 60 итераций
train mse: 0.749
test mse: 0.807
ALS, 8 факторов, l2 10, 40 итераций
train mse: 0.758
test mse: 0.804
ALS, 8 факторов, l2 10, 40 итераций
train mse: 0.736
test mse: 0.799
MCMC, 8 факторов, 70 итераций
train mse: 0.768
test mse: 0.798
Время работы
Movielens
Загрузка train и test датасетов, 8 факторов:
our fm: 10 с.
libFM: 36 с.
Загрузка датасетов + 10 итераций AdaGrad / SGD, 8 факторов:
our fm: 164 с.
our fm, 4 потока: 65 c. (206 с. проц. время)
our fm, 8 потоков: 50 с. (268 с. проц. время)
libFM: 121 c.
Загрузка датасетов + 10 итераций ALS, 8 факторов:
our fm: 295 с.
our fm, 4 потока: 112 c. (374 с. проц. время)
our fm, 8 потоков: 105 с. (661 с. проц. время)
libFM: 295 с.