v0.2.1
Tongjilibo committed Sep 5, 2022
1 parent 352a60d commit 6ed4973
Showing 5 changed files with 155 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -27,13 +27,14 @@ pip install git+https://www.github.com/Tongjilibo/bert4torch.git
- **Rich examples**: includes solutions for [pretrain](https://github.com/Tongjilibo/bert4torch/blob/master/examples/pretrain), [sentence_classfication](https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication), [sentence_embedding](https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_embedding), [sequence_labeling](https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling), [relation_extraction](https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction), [seq2seq](https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq) and more
- **Experimentally validated**: [validated](https://github.com/Tongjilibo/bert4torch/blob/master/examples/Performance.md) on public datasets, using the following [examples datasets](https://github.com/Tongjilibo/bert4torch/blob/master/examples/README.md)
- **Easy-to-use tricks**: integrates common [tricks](https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick), plug and play
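-- **Other features**: can be used together with [models loaded from the transformers library](https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_load_transformers_model.py); concise and efficient calling convention; dynamic training progress bar; prints parameter counts with torchinfo; records the training process with tensorboard; customizable fit process for advanced needs
+- **Other features**: can be used together with [models loaded from the transformers library](https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_load_transformers_model.py); concise and efficient calling convention; dynamic training progress bar; prints parameter counts with torchinfo; default Logger and Tensorboard for easy recording of the training process; customizable fit process for advanced needs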

## Quick start
- [Quick-start tutorial](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials/Tutorials.md), [tutorial examples](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials), [practical examples](https://github.com/Tongjilibo/bert4torch/blob/master/examples)
- [Introduction to bert4torch (Zhihu)](https://zhuanlan.zhihu.com/p/486329434), [bert4torch quick start (Zhihu)](https://zhuanlan.zhihu.com/p/508890807), [bert4torch has been updated yet again](https://zhuanlan.zhihu.com/p/560885427?)

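## Release notes
+- **v0.2.1**: compatible with torch<=1.7.1, where torch.div has no rounding_mode; adds custom metrics; supports resuming training from a checkpoint; adds default Logger and Tensorboard logging
- **v0.2.0**: compatible with torch<1.9.0, which lacks take_along_dim; fixes the position-embedding 514 issue in bart; fixes Sptokenizer not converting symbols; prints a timestamp at the start of each Epoch; adds parallel_apply
- **v0.1.9**: adds the mixup/manifold_mixup/temporal_ensembling strategies; fixes empty param.grad in the pgd strategy; the tokenizer now supports batches
- **v0.1.8**: fixes the sudden loss spike in CRF training; fixes the high GPU memory usage of the token_type_ids input in xlnet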
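
The v0.2.1 entry mentions custom metrics. In the tutorial script added by this commit, a metric is a plain callable `acc(y_pred, y_true)` handed to `model.compile` as `metrics={'acc': acc}`. As a rough, dependency-free illustration of what that metric computes, the sketch below applies the same argmax-and-compare logic to plain Python lists. The name `accuracy` and the list-based inputs are assumptions made for illustration and are not part of bert4torch.

```python
from typing import Sequence


def accuracy(y_pred: Sequence[Sequence[float]], y_true: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must have the same length")
    if not y_true:
        raise ValueError("cannot compute accuracy on an empty batch")
    correct = 0
    for scores, label in zip(y_pred, y_true):
        # argmax over the class scores of one sample
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(y_true)


if __name__ == "__main__":
    logits = [[0.1, 0.9], [2.0, -1.0], [0.3, 0.4]]
    labels = [1, 0, 0]
    print(accuracy(logits, labels))
```

The tensor version in the tutorial does the same thing with `torch.argmax` and `eq`, then divides by `y_true.numel()`.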
2 changes: 1 addition & 1 deletion bert4torch/__init__.py
@@ -1,4 +1,4 @@
#! -*- coding: utf-8 -*-


-__version__ = '0.2.0'
+__version__ = '0.2.1'
2 changes: 2 additions & 0 deletions examples/README.md
@@ -110,8 +110,10 @@
- [task_nl2sql_baseline.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_nl2sql_baseline.py): [a Baseline for Zhuiyi Technology's 2019 NL2SQL challenge](https://kexue.fm/archives/6771)

### Tutorials
- [Tutorials](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials/Tutorials): tutorial documentation.
- [tutorials_custom_fit_progress.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials/tutorials_custom_fit_progress.py): tutorial; a custom fit function for the training process (with the training progress bar integrated), useful for advanced needs such as half precision and gradient clipping.
- [tutorials_load_transformers_model.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials/tutorials_load_transformers_model.py): tutorial; loads a model from the transformers package so that the tricks built into bert4torch, such as adversarial training, can be used with it.
- [tutorials_small_tips.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials/tutorials_small_tips.py): tutorial; a collection of common tips.

## Datasets used
| Dataset name | Purpose | Download link |
149 changes: 149 additions & 0 deletions examples/tutorials/tutorials_small_tips.py
@@ -0,0 +1,149 @@
#! -*- coding:utf-8 -*-
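# Uses text classification as an example to show how to use some tips:
# printing parameters with torchinfo, custom metrics, resuming training from a checkpoint, default Logger and Tensorboard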

from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, Callback, Logger, Tensorboard, text_segmentate, ListDataset, seed_everything, get_pool_emb
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchinfo import summary
import os

maxlen = 256
batch_size = 16
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
choice = 'train'  # 'train' runs training, 'infer' runs inference

# Fix the random seed
seed_everything(42)

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Load the dataset
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filenames):
        """Load the data, splitting texts into sentences of at most maxlen where possible
        """
        D = []
        seps, strips = u'\n。!?!?;;,, ', u';;,, '
        for filename in filenames:
            with open(filename, encoding='utf-8') as f:
                for l in f:
                    text, label = l.strip().split('\t')
                    for t in text_segmentate(text, maxlen - 2, seps, strips):
                        D.append((t, int(label)))
        return D

def collate_fn(batch):
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    for text, label in batch:
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_labels.append([label])

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
    return [batch_token_ids, batch_segment_ids], batch_labels.flatten()

# Load the datasets
train_dataloader = DataLoader(MyDataset(['F:/Projects/data/corpus/sentence_classification/sentiment/sentiment.train.data']), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset(['F:/Projects/data/corpus/sentence_classification/sentiment/sentiment.valid.data']), batch_size=batch_size, collate_fn=collate_fn)
test_dataloader = DataLoader(MyDataset(['F:/Projects/data/corpus/sentence_classification/sentiment/sentiment.test.data']), batch_size=batch_size, collate_fn=collate_fn)

# Define the model structure on top of bert
class Model(BaseModel):
    def __init__(self, pool_method='cls') -> None:
        super().__init__()
        self.pool_method = pool_method
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True)
        self.dropout = nn.Dropout(0.1)
        self.dense = nn.Linear(self.bert.configs['hidden_size'], 2)

    def forward(self, token_ids, segment_ids):
        hidden_states, pooling = self.bert([token_ids, segment_ids])
        pooled_output = get_pool_emb(hidden_states, pooling, token_ids.gt(0).long(), self.pool_method)
        output = self.dropout(pooled_output)
        output = self.dense(output)
        return output
model = Model().to(device)
# Print the model structure and parameter counts with torchinfo
summary(model, input_data=next(iter(train_dataloader))[0])

# Custom metric: fraction of samples whose argmax prediction matches the label
def acc(y_pred, y_true):
    y_pred = torch.argmax(y_pred, dim=-1)
    return torch.sum(y_pred.eq(y_true)).item() / y_true.numel()

# Define the loss and optimizer; custom ones are supported here
optimizer = optim.Adam(model.parameters(), lr=2e-5)

if os.path.exists('last_model.pt'):
    model.load_weights('last_model.pt')  # Load the model weights
if os.path.exists('last_steps.pt'):
    model.load_steps_params('last_steps.pt')  # Load the training progress parameters, used to resume training
if os.path.exists('last_optimizer.pt'):
    state_dict = torch.load('last_optimizer.pt', map_location='cpu')  # Load the optimizer state, used to resume training
    optimizer.load_state_dict(state_dict)

model.compile(
    loss=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    metrics={'acc': acc}
)

class Evaluator(Callback):
    """Evaluate and save
    """
    def __init__(self):
        self.best_val_acc = 0.

    def on_epoch_end(self, global_step, epoch, logs=None):
        val_acc = self.evaluate(valid_dataloader)
        test_acc = self.evaluate(test_dataloader)
        logs['val/acc'] = val_acc
        logs['test/acc'] = test_acc
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            # model.save_weights('best_model.pt')
        print(f'val_acc: {val_acc:.5f}, test_acc: {test_acc:.5f}, best_val_acc: {self.best_val_acc:.5f}\n')

        model.save_weights('last_model.pt', prefix=None)  # Save the model weights
        model.save_steps_params('last_steps.pt')  # Save the training progress (current epoch and step), used to resume training
        torch.save(optimizer.state_dict(), 'last_optimizer.pt')  # Save the optimizer state, used to resume training

    # Define the evaluation function
    def evaluate(self, data):
        total, right = 0., 0.
        for x_true, y_true in data:
            y_pred = model.predict(x_true).argmax(axis=1)
            total += len(y_true)
            right += (y_true == y_pred).sum().item()
        return right / total

def inference(texts):
    '''Run inference one sample at a time
    '''
    for text in texts:
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        token_ids = torch.tensor(token_ids, dtype=torch.long, device=device)[None, :]
        segment_ids = torch.tensor(segment_ids, dtype=torch.long, device=device)[None, :]

        logit = model.predict([token_ids, segment_ids])
        y_pred = torch.argmax(torch.softmax(logit, dim=-1)).cpu().numpy()
        print(text, ' ----> ', y_pred)

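if __name__ == '__main__':
    if choice == 'train':
        evaluator = Evaluator()
        model.fit(train_dataloader, epochs=10, steps_per_epoch=100, callbacks=[evaluator, Logger('test.log'), Tensorboard('./')])
    else:
        # best_model.pt is only written if the save_weights call in Evaluator is uncommented
        model.load_weights('best_model.pt')
        # The inputs mean 'I am especially happy today' and 'I am especially angry today'
        inference(['我今天特别开心', '我今天特别生气'])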
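
The resume logic in this script reads three separate files (`last_model.pt`, `last_steps.pt`, `last_optimizer.pt`), so each load is safest when guarded by a check on the file it actually reads. Below is a minimal sketch of that guard using only the standard library. The helper `find_resume_files` and the role names are hypothetical and not part of bert4torch; only the file names come from the script.

```python
import os
import tempfile
from typing import Dict

# File names taken from the tutorial script; the role keys are illustrative only.
RESUME_FILES = {
    "weights": "last_model.pt",
    "steps": "last_steps.pt",
    "optimizer": "last_optimizer.pt",
}


def find_resume_files(directory: str) -> Dict[str, str]:
    """Return {role: path} for every resume file that exists in `directory`."""
    found = {}
    for role, name in RESUME_FILES.items():
        path = os.path.join(directory, name)
        if os.path.isfile(path):  # check the same file that would be loaded
            found[role] = path
    return found


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Simulate an interrupted run that saved weights but no optimizer state.
        open(os.path.join(tmp, "last_model.pt"), "wb").close()
        print(sorted(find_resume_files(tmp)))
```

With a mapping like this, a caller would load only the roles that are present, so a missing optimizer file is skipped rather than raising an error during `torch.load`.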
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@

setup(
    name='bert4torch',
-    version='0.2.0',
+    version='0.2.1',
    description='an elegant bert4torch',
    long_description='bert4torch: https://github.com/Tongjilibo/bert4torch',
    license='MIT Licence',
