Инструкция для быстрого старта
Бинарная классификация на текстовых данных из RuTweetCorp (https://study.mokoron.com/)
отрицательный: 0
положительный: 1
Используются очищенные данные русскоязычного твиттера длинее 100 символов.
RuTweetCorp (https://study.mokoron.com/)
Класс CustomDataset необходим для использования с библиотекой transformers. Наследуется от класса Dataset. В нем определяются 3 обязательные функции: init, len, getitem. основное предназначение - возвращает токенизированные данные в нужном формате.
При инициализации классификатора выполняются следующие действия:
- Скачиваются модель и токенизатор из репозитория huggingface;
- Определяется наличие целевого устройства для вычислений;
- Определяется размерность ембеддингов;
- Задается количество классов;
- Задается количество эпох для обучения.
Для обучения BERT нужно инициализировать несколько вспомогательных элементов:
- DataLoader: нужен для создания батчей;
- Optimizer: оптимизатор градиентного спуска;
- Scheduler: планировщик, нужен для настройки параметров оптимизатора;
- Loss: функция потерь, считаем по ней ошибку модели.
- Обучение для одной эпохи описано в методе fit.
- Данные в цикле батчами генерируются с помощью DataLoader;
- Батч подается в модель;
- На выходе получаем распределение вероятности по классам и значение ошибки;
- Делаем шаг на всех вспомогательных функциях:
- loss.backward: обратное распространение ошибки;
- clip_grad_norm: обрезаем градиенты для предотвращения "взрыва" градиентов;
- optimizer.step: шаг оптимизатора;
- scheduler.step: шаг планировщика;
- optimizer.zero_grad: обнуляем градиенты.
- Проверку на валидационной выборке проводим с помощью метода eval. При этом используем метод torch.no_grad для предотвращения обучения на валидационной выборке.
- Для обучения на нескольких эпохах используется метод train, в котором последовательно вызываются методы fit и eval.
Для предсказания класса для нового текста используется метод predict, который имеет смысл вызывать только после обучения модели.
Метод работает следующим образом:
- Токенизируется входной текст;
- Токенизированный текст подается в модель;
- На выходе получаем вероятности классов;
- Возвращаем метку наиболее вероятного класса.
- Рубцова Ю. Автоматическое построение и анализ корпуса коротких текстов (постов микроблогов) для задачи разработки и тренировки тонового классификатора //Инженерия знаний и технологии семантического веба. – 2012. – Т. 1. – С. 109-116.
- https://habr.com/ru/post/567028/