Skip to content

shitkov/bert4classification

Repository files navigation

[RU|EN]

Open In Colab

BERT для задачи классификации

Инструкция для быстрого старта
Бинарная классификация на текстовых данных из RuTweetCorp (https://study.mokoron.com/)
отрицательный: 0
положительный: 1

Зачем?

Когда я осваиваю какой-либо новый подход, обычно находятся статьии туториалы, в которых все чрезвычайно подробно, от первичной обработки данных, до построения кривых обучения. Мне же всегда хотелось быстро понять суть подходаы и сразу начать использовать имеющиеся наработки, а не мучительно разбираться с простыней чужого кода. Поэтому я решил сделать по возможности максимально простое и прозрачное решение, которое не будет перегружено лишним кодом, в котором можно легко и быстро разобраться.
Про BERT я писать ничего не буду - про него полно отличных статей, так что просто будем использовать его в качестве черного ящика.

Структура

Данные для обучения

Используются очищенные данные русскоязычного твиттера длинее 100 символов.
RuTweetCorp (https://study.mokoron.com/)

CustomDataset

Класс CustomDataset необходим для использования с библиотекой transformers. Наследуется от класса Dataset. В нем определяются 3 обязательные функции: init, len, getitem. основное предназначение - возвращает токенизированные данные в нужном формате.

Initialize

При инициализации классификатора выполняются следующие действия:

  • Скачиваются модель и токенизатор из репозитория huggingface;
  • Определяется наличие целевого устройства для вычислений;
  • Определяется размерность ембеддингов;
  • Задается количество классов;
  • Задается количество эпох для обучения.

Preparation

Для обучения BERT нужно инициализировать несколько вспомогательных элементов:

  • DataLoader: нужен для создания батчей;
  • Optimizer: оптимизатор градиентного спуска;
  • Scheduler: планировщик, нужен для настройки параметров оптимизатора;
  • Loss: функция потерь, считаем по ней ошибку модели.

Train

  • Обучение для одной эпохи описано в методе fit.
    • Данные в цикле батчами генерируются с помощью DataLoader;
    • Батч подается в модель;
    • На выходе получаем распределение вероятности по классам и значение ошибки;
    • Делаем шаг на всех вспомогательных функциях:
      • loss.backward: обратное распространение ошибки;
      • clip_grad_norm: обрезаем градиенты для предотвращения "взрыва" градиентов;
      • optimizer.step: шаг оптимизатора;
      • scheduler.step: шаг планировщика;
      • optimizer.zero_grad: обнуляем градиенты.
  • Проверку на валидационной выборке проводим с помощью метода eval. При этом используем метод torch.no_grad для предотвращения обучения на валидационной выборке.
  • Для обучения на нескольких эпохах используется метод train, в котором последовательно вызываются методы fit и eval.

Inference

Для предсказания класса для нового текста используется метод predict, который имеет смысл вызывать только после обучения модели.
Метод работает следующим образом:

  • Токенизируется входной текст;
  • Токенизированный текст подается в модель;
  • На выходе получаем вероятности классов;
  • Возвращаем метку наиболее вероятного класса.

Заключение

Хотелось максимально просто, но все равно получилось как-то объемно. Прошу понять и простить. Пис!

About

Finetuning BERT for classification task

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published