[PaddlePaddle Hackathon] Task 55 submission #1133

Merged · 45 commits · Dec 10, 2021

Commits
0808e64
add roberta model
nosaydomore Oct 8, 2021
af30b0a
rollback
nosaydomore Oct 8, 2021
acc0ad4
.
nosaydomore Oct 8, 2021
c6dcf08
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
nosaydomore Oct 8, 2021
e44930d
upd roberta convert
nosaydomore Oct 8, 2021
fca1c9c
upd
nosaydomore Oct 8, 2021
e9f185b
upd
nosaydomore Oct 8, 2021
654c978
upd compare
nosaydomore Oct 9, 2021
00e786f
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
nosaydomore Oct 9, 2021
fcf5ef6
upd unitest
nosaydomore Oct 10, 2021
4b2328e
upd unnitest
nosaydomore Oct 10, 2021
e4eae49
upd roberta convert
nosaydomore Oct 10, 2021
66d4fc1
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
nosaydomore Oct 10, 2021
819bfca
clean code
nosaydomore Oct 10, 2021
d1eedfe
upd tokenizer config
nosaydomore Oct 10, 2021
c369e6c
upd model_config
nosaydomore Oct 10, 2021
f0162dd
Merge branch 'develop' into task_55
nosaydomore Oct 11, 2021
9e08918
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
nosaydomore Oct 14, 2021
0efd480
add pretrain weight file json
nosaydomore Oct 14, 2021
2cdaa7d
Merge branch 'task_55' of https://github.com/nosaydomore/PaddleNLP in…
nosaydomore Oct 14, 2021
5299048
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
nosaydomore Oct 14, 2021
981c886
Merge branch 'develop' into task_55
yingyibiao Oct 18, 2021
0157743
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
nosaydomore Oct 19, 2021
2a30ef5
upd pretrained_model_rst
nosaydomore Oct 19, 2021
e7f9047
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
nosaydomore Oct 19, 2021
d1842cb
Merge branch 'task_55' of https://github.com/nosaydomore/PaddleNLP in…
nosaydomore Oct 19, 2021
3adf20f
rename pretrain model name
nosaydomore Oct 19, 2021
aacccd0
upd tokenizer pretrained_resource_files_map
nosaydomore Oct 19, 2021
132cb5f
upd tokenizer pretrained_resource_files_map
nosaydomore Oct 19, 2021
715576a
upd tokenizer info
nosaydomore Oct 19, 2021
ab2166d
upd model config
nosaydomore Oct 27, 2021
60897f4
Merge branch 'develop' into task_55
nosaydomore Oct 27, 2021
47b0eca
Merge branch 'develop' into task_55
yingyibiao Nov 5, 2021
068153e
Merge branch 'develop' into task_55
yingyibiao Dec 3, 2021
cc124f1
Merge branch 'PaddlePaddle:develop' into task_55
nosaydomore Dec 6, 2021
de92d71
Merge branch 'PaddlePaddle:develop' into task_55
nosaydomore Dec 7, 2021
53e0448
Merge branch 'develop' into task_55
yingyibiao Dec 8, 2021
31ab67d
Merge branch 'PaddlePaddle:develop' into task_55
nosaydomore Dec 8, 2021
1442c14
unify tokenizer
nosaydomore Dec 9, 2021
c20ff2b
unify tokenizer
nosaydomore Dec 9, 2021
556d73b
fix conflict
nosaydomore Dec 9, 2021
7da5698
Merge branch 'develop' into task_55
yingyibiao Dec 10, 2021
8fdafcb
Update tokenizer.py
yingyibiao Dec 10, 2021
1696221
Update README.md
yingyibiao Dec 10, 2021
04a7d83
Merge branch 'develop' into task_55
yingyibiao Dec 10, 2021
214 changes: 214 additions & 0 deletions community/nosaydomore/convert_roberta.py
@@ -0,0 +1,214 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse
import paddle
import torch
import os
import json

from paddle.utils.download import get_path_from_url

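# Substring replacements mapping Hugging Face parameter names to their
# PaddleNLP equivalents; each fragment is swapped wherever it occurs in a key.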
huggingface_to_paddle = {
"embeddings.LayerNorm": "embeddings.layer_norm",
"encoder.layer": "encoder.layers",
"attention.self.query": "self_attn.q_proj",
"attention.self.key": "self_attn.k_proj",
"attention.self.value": "self_attn.v_proj",
"attention.output.dense": "self_attn.out_proj",
"intermediate.dense": "linear1",
"output.dense": "linear2",
"attention.output.LayerNorm": "norm1",
"output.LayerNorm": "norm2",
"qa_outputs": 'classifier',
'lm_head.bias': 'lm_head.decoder.bias'
}

convert_model_name_list = [
"roberta-base",
"roberta-large",
"deepset/roberta-base-squad2",
"uer/roberta-base-finetuned-chinanews-chinese",
"sshleifer/tiny-distilroberta-base",
"uer/roberta-base-finetuned-cluener2020-chinese",
"uer/roberta-base-chinese-extractive-qa",
]

link_template = "https://huggingface.co/{}/resolve/main/pytorch_model.bin"

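# Init configs mirrored from each checkpoint's config.json on the Hugging Face Hub.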
pretrained_init_configuration = {
"roberta-base": {
"attention_probs_dropout_prob": 0.1,
"layer_norm_eps": 1e-05,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 514,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"type_vocab_size": 1,
"vocab_size": 50265
},
"roberta-large": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 514,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"type_vocab_size": 1,
"layer_norm_eps": 1e-05,
"vocab_size": 50265
},
"deepset/roberta-base-squad2": {
"layer_norm_eps": 1e-05,
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 514,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"type_vocab_size": 1,
"vocab_size": 50265
},
"uer/roberta-base-finetuned-chinanews-chinese": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 21128
},
"sshleifer/tiny-distilroberta-base": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 2,
"initializer_range": 0.02,
"intermediate_size": 2,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 514,
"num_attention_heads": 2,
"num_hidden_layers": 2,
"pad_token_id": 1,
"type_vocab_size": 1,
"vocab_size": 50265
},
"uer/roberta-base-finetuned-cluener2020-chinese": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 21128
},
"uer/roberta-base-chinese-extractive-qa": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 21128
}
}


def convert_pytorch_checkpoint_to_paddle(pytorch_src_base_path,
paddle_dump_base_path):
for model_name in convert_model_name_list:
model_state_url = link_template.format(model_name)

paddle_dump_path = os.path.join(paddle_dump_base_path,
model_name.split('/')[-1])

if os.path.exists(
os.path.join(paddle_dump_path, 'model_state.pdparams')):
continue
if not os.path.exists(paddle_dump_path):
os.makedirs(paddle_dump_path)

with open(os.path.join(paddle_dump_path, 'model_config.json'),
'w') as fw:
json.dump(pretrained_init_configuration[model_name], fw)

_ = get_path_from_url(model_state_url, paddle_dump_path)
pytorch_checkpoint_path = os.path.join(paddle_dump_path,
'pytorch_model.bin')
pytorch_state_dict = torch.load(
pytorch_checkpoint_path, map_location="cpu")
paddle_state_dict = OrderedDict()
for k, v in pytorch_state_dict.items():
is_transpose = False
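            # torch.nn.Linear stores weights as [out_features, in_features];
            # paddle.nn.Linear expects [in_features, out_features], so 2-D
            # weights outside embeddings/LayerNorm are transposed.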
if k[-7:] == ".weight":
if ".embeddings." not in k and ".LayerNorm." not in k:
if v.ndim == 2:
v = v.transpose(0, 1)
is_transpose = True
oldk = k
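            # 'lm_head.bias' duplicates 'lm_head.decoder.bias'; skip the copy.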
            if k == 'lm_head.bias' and 'lm_head.decoder.bias' in pytorch_state_dict:
                continue

for huggingface_name, paddle_name in huggingface_to_paddle.items():
k = k.replace(huggingface_name, paddle_name)
if k[:5] == 'bert.':
k = k.replace('bert.', 'roberta.')

print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
paddle_state_dict[k] = v.data.numpy()
del pytorch_state_dict

paddle_dump_path = os.path.join(paddle_dump_path,
'model_state.pdparams')
paddle.save(paddle_state_dict, paddle_dump_path)


if __name__ == "__main__":
pytorch_src_base_path = os.path.dirname(os.path.realpath(__file__))
paddle_dump_base_path = pytorch_src_base_path
convert_pytorch_checkpoint_to_paddle(pytorch_src_base_path,
paddle_dump_base_path)
108 changes: 108 additions & 0 deletions community/nosaydomore/deepset_roberta_base_squad2/README.md
@@ -0,0 +1,108 @@
## deepset/roberta-base-squad2

A QA model trained on SQuAD 2.0.

Model source: https://huggingface.co/deepset/roberta-base-squad2

Usage example:
```python
from paddlenlp.transformers import (
RobertaModel, RobertaForMaskedLM, RobertaForQuestionAnswering,
RobertaForSequenceClassification, RobertaForTokenClassification)
from paddlenlp.transformers import RobertaBPETokenizer, RobertaTokenizer
import paddle
import os
import numpy as np

def decode(start, end, topk, max_answer_len, undesired_tokens):
"""
    Take the start/end logits of a :obj:`ModelForQuestionAnswering` and generate
    the probability of each (start, end) span being the actual answer.
"""
# Ensure we have batch axis
if start.ndim == 1:
start = start[None]

if end.ndim == 1:
end = end[None]
    # Compute the score of each (start, end) tuple being the real answer
outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

    # Remove candidates with end < start or end - start > max_answer_len
candidates = np.tril(np.triu(outer), max_answer_len - 1)

# Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
scores_flat = candidates.flatten()
if topk == 1:
idx_sort = [np.argmax(scores_flat)]
elif len(scores_flat) < topk:
idx_sort = np.argsort(-scores_flat)
else:
idx = np.argpartition(-scores_flat, topk)[0:topk]
idx_sort = idx[np.argsort(-scores_flat[idx])]

starts, ends = np.unravel_index(idx_sort, candidates.shape)[1:]
desired_spans = np.isin(starts, undesired_tokens.nonzero()) & np.isin(
ends, undesired_tokens.nonzero())
starts = starts[desired_spans]
ends = ends[desired_spans]
scores = candidates[0, starts, ends]

return starts, ends, scores

tokenizer = RobertaBPETokenizer.from_pretrained('deepset/roberta-base-squad2')
questions = ['Where do I live?']
contexts = ['My name is Sarah and I live in London']

token = tokenizer(
questions,
contexts,
stride=128,
max_seq_len=64,
return_attention_mask=True,
return_special_tokens_mask=True)
# print(token)
special_tokens_mask = token[0]['special_tokens_mask']
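# RoBERTa encodes the pair as <s> question </s> </s> context </s>; scanning for
# the third special token finds the index where the context starts.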
count = 3
st_idx = 0
for i in special_tokens_mask:
st_idx += 1
if i == 1:
count -= 1
if count == 0:
break

input_ids = token[0]['input_ids']
offset_mapping = token[0]['offset_mapping']

input_ids = paddle.to_tensor(input_ids, dtype='int64').unsqueeze(0)

model = RobertaForQuestionAnswering.from_pretrained('deepset/roberta-base-squad2')
model.eval()
start, end = model(input_ids=input_ids)
start_ = start[0].numpy()
end_ = end[0].numpy()
undesired_tokens = np.ones_like(input_ids[0].numpy())

undesired_tokens[1:st_idx] = 0
undesired_tokens[-1] = 0

# Generate mask
undesired_tokens_mask = undesired_tokens == 0.0

# Make sure non-context indexes in the tensor cannot contribute to the softmax
start_ = np.where(undesired_tokens_mask, -10000.0, start_)
end_ = np.where(undesired_tokens_mask, -10000.0, end_)

start_ = np.exp(start_ - np.log(
np.sum(np.exp(start_), axis=-1, keepdims=True)))
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
start_idx, end_idx, score = decode(start_, end_, 1, 64, undesired_tokens)
start_idx, end_idx = offset_mapping[start_idx[0]][0], offset_mapping[
end_idx[0]][1]
print("ans: {}".format(contexts[0][start_idx:end_idx]),
'score:{}'.format(score.item()))

'''
ans: London score:0.7772307395935059
'''
```
7 changes: 7 additions & 0 deletions community/nosaydomore/deepset_roberta_base_squad2/files.json
@@ -0,0 +1,7 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/deepset_roberta_base_squad2/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/deepset_roberta_base_squad2/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/deepset_roberta_base_squad2/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/deepset_roberta_base_squad2/vocab.json",
"merges_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/deepset_roberta_base_squad2/merges.txt"
}
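The `files.json` manifest lists where each asset of the community checkpoint is hosted. A minimal loading sketch, assuming PaddleNLP's community-model convention of resolving a `user/model` name to these hosted files (the identifier below is inferred from the directory layout, not confirmed):

```python
from paddlenlp.transformers import (RobertaBPETokenizer,
                                    RobertaForQuestionAnswering)

# "nosaydomore/deepset_roberta_base_squad2" is an assumed community name
# matching the directory above; adjust it to the actual published path.
tokenizer = RobertaBPETokenizer.from_pretrained(
    "nosaydomore/deepset_roberta_base_squad2")
model = RobertaForQuestionAnswering.from_pretrained(
    "nosaydomore/deepset_roberta_base_squad2")
```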
57 changes: 57 additions & 0 deletions community/nosaydomore/roberta_en_base/README.md
@@ -0,0 +1,57 @@
## roberta-base
Weight source: https://huggingface.co/roberta-base

A roberta-base masked-LM model pretrained on English text.

```python
from paddlenlp.transformers import (
RobertaModel, RobertaForMaskedLM, RobertaForQuestionAnswering,
RobertaForSequenceClassification, RobertaForTokenClassification)
from paddlenlp.transformers import RobertaBPETokenizer, RobertaTokenizer
import paddle
import os
import numpy as np

model = RobertaForMaskedLM.from_pretrained('roberta-base')
tokenizer = RobertaBPETokenizer.from_pretrained('roberta-base')
text = ["The man worked as a", "."] #"The man worked as a <mask>."
tokens_list = []
for i in range(2):
tokens_list.append(tokenizer.tokenize(text[i]))

tokens = ['<s>']
tokens.extend(tokens_list[0])
tokens.extend(['<mask>'])
tokens.extend(tokens_list[1])
tokens.extend(['</s>'])
token_ids = tokenizer.convert_tokens_to_ids(tokens)
# print(token_ids)

model.eval()
input_ids = paddle.to_tensor([token_ids])
with paddle.no_grad():
pd_outputs = model(input_ids)

pd_outputs = pd_outputs[0]

pd_outputs_sentence = "paddle: "
for i, id in enumerate(token_ids):
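    # 50264 is the id of the <mask> token in the roberta-base vocabulary.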
if id == 50264:
scores, index = paddle.nn.functional.softmax(pd_outputs[i],
-1).topk(5)
tokens = tokenizer.convert_ids_to_tokens(index.tolist())
outputs = []
for score, tk in zip(scores.tolist(), tokens):
outputs.append(f"{tk}={score}")
pd_outputs_sentence += "[" + "||".join(outputs) + "]" + " "
else:
pd_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens(
[id], skip_special_tokens=True)) + " "

print(pd_outputs_sentence)

'''
paddle: The Ġman Ġworked Ġas Ġa [Ġmechanic=0.08702345192432404||Ġwaiter=0.08196478337049484||Ġbutcher=0.07332248240709305||Ġminer=0.046321991831064224||Ġguard=0.040149785578250885] .
'''
```
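Here `Ġ` is the byte-level BPE marker for a token preceded by a space, so `Ġman` decodes to ` man`.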
7 changes: 7 additions & 0 deletions community/nosaydomore/roberta_en_base/files.json
@@ -0,0 +1,7 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/roberta_en_base/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/roberta_en_base/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/roberta_en_base/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/roberta_en_base/vocab.json",
"merges_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/nosaydomore/roberta_en_base/merges.txt"
}