Инструкция для быстрого старта
Бинарная классификация на текстовых данных из RuTweetCorp (https://study.mokoron.com/)
отрицательный: 0
положительный: 1
Когда я осваиваю какой-либо новый подход, обычно находятся статьии туториалы, в которых все чрезвычайно подробно, от первичной обработки данных, до построения кривых обучения. Мне же всегда хотелось быстро понять суть подходаы и сразу начать использовать имеющиеся наработки, а не мучительно разбираться с простыней чужого кода. Поэтому я решил сделать по возможности максимально простое и прозрачное решение, которое не будет перегружено лишним кодом, в котором можно легко и быстро разобраться.
Про BERT я писать ничего не буду - про него полно отличных статей, так что просто будем использовать его в качестве черного ящика.
Используются очищенные данные русскоязычного твиттера длинее 100 символов.
RuTweetCorp (https://study.mokoron.com/)
Класс CustomDataset необходим для использования с библиотекой transformers. Наследуется от класса Dataset. В нем определяются 3 обязательные функции: init, len, getitem. основное предназначение - возвращает токенизированные данные в нужном формате.
При инициализации классификатора выполняются следующие действия:
- Скачиваются модель и токенизатор из репозитория huggingface;
- Определяется наличие целевого устройства для вычислений;
- Определяется размерность ембеддингов;
- Задается количество классов;
- Задается количество эпох для обучения.
Для обучения BERT нужно инициализировать несколько вспомогательных элементов:
- DataLoader: нужен для создания батчей;
- Optimizer: оптимизатор градиентного спуска;
- Scheduler: планировщик, нужен для настройки параметров оптимизатора;
- Loss: функция потерь, считаем по ней ошибку модели.
- Обучение для одной эпохи описано в методе fit.
- Данные в цикле батчами генерируются с помощью DataLoader;
- Батч подается в модель;
- На выходе получаем распределение вероятности по классам и значение ошибки;
- Делаем шаг на всех вспомогательных функциях:
- loss.backward: обратное распространение ошибки;
- clip_grad_norm: обрезаем градиенты для предотвращения "взрыва" градиентов;
- optimizer.step: шаг оптимизатора;
- scheduler.step: шаг планировщика;
- optimizer.zero_grad: обнуляем градиенты.
- Проверку на валидационной выборке проводим с помощью метода eval. При этом используем метод torch.no_grad для предотвращения обучения на валидационной выборке.
- Для обучения на нескольких эпохах используется метод train, в котором последовательно вызываются методы fit и eval.
Для предсказания класса для нового текста используется метод predict, который имеет смысл вызывать только после обучения модели.
Метод работает следующим образом:
- Токенизируется входной текст;
- Токенизированный текст подается в модель;
- На выходе получаем вероятности классов;
- Возвращаем метку наиболее вероятного класса.
Хотелось максимально просто, но все равно получилось как-то объемно. Прошу понять и простить. Пис!