Задача представляет собой kaggle-соревнование по классификации китов на изображениях.
Цель работы: решить данную задачу, используя подход mertic learning
, а также разработать демо-сервис для инференса изображений.
Репозиторий имеет следующую структуру:
api
- папка, содержащая реализацию демо-сервиса (более подробная информация по запуску сервиса находится ниже)examples
- папка с подготовленными примерами для инференсаmodels
- в данной папке находятся все необходимые компоненты для инференса (модели для получения эмбеддингов, knn, label encoder и др.)app.py
- скрипт запуска веб-сервиса
data
- папка, где необходимо расположить данныеtrain, test, train.csv
experiments
- папка для локального логгирования артефактов обучения (создается автоматически при последовательном запуске пайплайна обучения)src
- скрипты для преобработки данных, обучения моделей, рассчета метрик, обучения knn, получения эмбеддингов и создания сабмита (submission.csv)config.py
- конфиг-класс, где задаются параметры обучения
Датасет состоит из двух папок с изображениями китов:
- train - 25361 изображений
- test - 7960 изображений
Также приложен файл train.csv
, который содержит информацию о классе кита на каждом изображении. Всего в наборе данных представлено 5005 уникальных видов китов, из которых обнаружено:
- чаще всего встречается класс
new_whale
- 9664 из 25361 изображений - 2073 класса содержат только одно изображение
Для обучения и валидации пайплайна набор данных был разделен на две части, причем в обучающую выборку вошли:
- все классы (2073 класса), содержащие только одно изображение на класс;
- 7731 изображений из класса new_whale, остальные 1933 - в валидационную выборку;
- все остальные классы были разделены с помощью метода StratifiedKFold (тут понял, что ошибся, так как изображения разделились поровну между train и val выборками - 2931 уникальных классов по 6812 изображений на выборку).
К настоящему времени в качестве энкодеров для получения эмбеддингов были опробованы следующие модели:
- непредобученный
resnet18
из torchvision; - предобученный
ViT
(Visual Transformer) из hugging-face (веса -google/vit-base-patch16-224
)
На выходе каждой модели получал эмбеддинг размера 512.
В качестве лоссов использовал реализации из библиотеки pytorch-metric-learning
. Были опробованы два лосса:
- ArcFace
- ProxyAnchorLoss
Забегу наперед и скажу, что лучше всего себя показали модели, обученные с помощью ProxyAnchorLoss.
Для запуска всего пайплайна от обработки данных до создании сабмишн-файла нужно придерживаться следующей последовательности:
- Запускаем
splits.py
для разделения данных наtrain_split.csv
иval_split.csv
. - Указываем в
config.py
необходимые параметры обучения. - Запускаем
train.py
для обучения модели. Получаем чекпоинты в папкеexperiments/{LOSS_NAME}/{MODEL_NAME}
. - С помощью
embeddings.py
получаем эмбеддинги для каждого изображения из папки train (сохраняются также вexperiments/{LOSS_NAME}/{MODEL_NAME}
в numpy-формате.npy
). - Обучаем с помощью
knn.py
классификатор для поиска наиближайших эмбеддингов. Сохраняем модель вexperiments/{LOSS_NAME}/{MODEL_NAME}
для последующего использования. - Запускаем
submission.py
для получения submission.csv. - Делаем сабмит на платформу:
kaggle competitions submit -c humpback-whale-identification -f experiments/proxy/ViT/knn_submission.csv -m "vit knn submission"
.
В качестве метрики качества на валидации использовалась метрика Recall: R@1, R@2, R@4, R@8, R@16, R@32.
Со временем в пайплайн обучения было добавлено логгирование метрик в WandB. Ссылка: https://wandb.ai/cv-itmo/whale_classification?workspace=user-dmitryai
Результаты основных экспериментов приведены в таблице ниже (было проведено еще несколько, но они не указаны в связи неуспешности).
Модель | R@1 | Leaderboard Score |
---|---|---|
resnet18_arcface | 0.104 | - |
resnet18_proxy | 0.105 | 0.20022 |
vit_proxy | 0.304 | 0.40347 |
Результаты показывают, что на данный момент лучше всех себя продемонстрировал ViT в связке с ProxyAnchorLoss.
Для реализации веб-сервиса был использован фреймворк Gradio
(ссылка на репозиторий), который позволяет реализовывать веб-приложения для демо-решений и production систем.
Веб-сервис позволяет опробовать две модели: resnet18_proxy
и vit_proxy
.
Для запуска веб-сервиса необходимо скачать архив (ссылка на гугл-диск) с необходимыми артефактами и моделями и распаковать его в папке api/models/
.
Запуск производится из корня репозитория командой python api/app.py
.
Перед этим возможно потребуется исполнение команды export PYTHONPATH="${PYTHONPATH}":pwd
.
В левой части можно загрузить изображение для инференса или выбрать из примеров ниже, где также указаны классы изображений. Справа выводится результирующая таблица с топ-5 ближайшими классами и минимальным косинусным расстоянием.
Изначально был опробован пайплайн, реализованный только с помощью компонент библиотеки pytorch-metric-learning, однако он давал низкие метрики, в следствие чего пайплайн был переписан с самостоятельной реализацией обучающего цикла и др.
В один момент были трудности с обучением ViT, так как пайплайн был изначально реализован под resnet, поэтому потребовалось внесение изменений в обучающий цикл и подсчет метрики Recall.
Также было опробовано два подхода к непосредственной классификации изображений:
- обучение эмбеддингов с помощью knn для получения k-ближайших соседей;
- усреднение эмбеддингов изображений по классам и поиск ближайшего класса путем расчета косинусного расстояния между инференсным изображением и base-файлом с усредненными векторами. Однако данный подход не зашел и показал практически нулевой скор на лидерборде (возможно, стоит более детально проверить этот пайплайн).
Что можно сделать еще:
- эксперименты с моделями и лоссами;
- эксперименты с knn;
- поправить разделение данных на train и val;
- отрефакторить код.