/bert4classification

Finetuning BERT for Russian twits classification

Primary LanguageJupyter Notebook

[RU|EN]

Open In Colab

BERT для задачи классификации

Инструкция для быстрого старта
Бинарная классификация на текстовых данных из RuTweetCorp (https://study.mokoron.com/)
отрицательный: 0
положительный: 1

Структура

Данные для обучения

Используются очищенные данные русскоязычного твиттера длинее 100 символов.
RuTweetCorp (https://study.mokoron.com/)

CustomDataset

Класс CustomDataset необходим для использования с библиотекой transformers. Наследуется от класса Dataset. В нем определяются 3 обязательные функции: init, len, getitem. основное предназначение - возвращает токенизированные данные в нужном формате.

Initialize

При инициализации классификатора выполняются следующие действия:

  • Скачиваются модель и токенизатор из репозитория huggingface;
  • Определяется наличие целевого устройства для вычислений;
  • Определяется размерность ембеддингов;
  • Задается количество классов;
  • Задается количество эпох для обучения.

Preparation

Для обучения BERT нужно инициализировать несколько вспомогательных элементов:

  • DataLoader: нужен для создания батчей;
  • Optimizer: оптимизатор градиентного спуска;
  • Scheduler: планировщик, нужен для настройки параметров оптимизатора;
  • Loss: функция потерь, считаем по ней ошибку модели.

Train

  • Обучение для одной эпохи описано в методе fit.
    • Данные в цикле батчами генерируются с помощью DataLoader;
    • Батч подается в модель;
    • На выходе получаем распределение вероятности по классам и значение ошибки;
    • Делаем шаг на всех вспомогательных функциях:
      • loss.backward: обратное распространение ошибки;
      • clip_grad_norm: обрезаем градиенты для предотвращения "взрыва" градиентов;
      • optimizer.step: шаг оптимизатора;
      • scheduler.step: шаг планировщика;
      • optimizer.zero_grad: обнуляем градиенты.
  • Проверку на валидационной выборке проводим с помощью метода eval. При этом используем метод torch.no_grad для предотвращения обучения на валидационной выборке.
  • Для обучения на нескольких эпохах используется метод train, в котором последовательно вызываются методы fit и eval.

Inference

Для предсказания класса для нового текста используется метод predict, который имеет смысл вызывать только после обучения модели.
Метод работает следующим образом:

  • Токенизируется входной текст;
  • Токенизированный текст подается в модель;
  • На выходе получаем вероятности классов;
  • Возвращаем метку наиболее вероятного класса.

Ссылки

  • Рубцова Ю. Автоматическое построение и анализ корпуса коротких текстов (постов микроблогов) для задачи разработки и тренировки тонового классификатора //Инженерия знаний и технологии семантического веба. – 2012. – Т. 1. – С. 109-116.
  • https://habr.com/ru/post/567028/