-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
216 changed files
with
11,125 additions
and
3,826 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0
...blocks/metrics/classification/__init__.py → docs/__init__.py
100755 → 100644
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
$ git clone https://github.com/NVIDIA/apex | ||
$ sed -i "s/or (bare_metal_minor != torch_binary_minor)//g" apex/setup.py | ||
$ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" apex/ | ||
$ rm -rf apex |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from torchblocks.utils import json_to_text | ||
from sklearn.model_selection import StratifiedKFold | ||
|
||
|
||
def get_data(data_path, datatype):
    """Load query/candidate pairs from a file holding one dict literal per line.

    Every non-blank line of ``data_path`` must be a plain dict literal with a
    ``'query'`` key and a ``'candidate'`` list whose items have a ``'text'``
    key. Training records additionally need ``'label'`` on each candidate;
    non-training records need a top-level ``'text_id'``.

    Args:
        data_path: Path of the file to read.
        datatype: ``'train'`` produces labelled pairs and drops records with
            an empty query as well as candidates with empty text. Any other
            value produces unlabelled pairs tagged with ``text_id`` and
            performs no filtering.

    Returns:
        A list of flat dicts: ``{'query', 'candidate', 'label'}`` for training
        data, ``{'text_id', 'query', 'candidate'}`` otherwise.

    Raises:
        ValueError, SyntaxError: If a line is not a plain literal.
        KeyError: If a record lacks one of the required keys.
    """
    # Imported locally so the function does not depend on a file-level import.
    import ast

    is_train = datatype == 'train'
    data = []
    with open(data_path) as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                # Tolerate blank/trailing lines instead of crashing on them.
                continue
            # literal_eval only accepts literals, so a crafted data file
            # cannot execute code the way a bare eval() would allow.
            record = ast.literal_eval(stripped)
            if is_train:
                if record['query'] == '':
                    continue
                for candidate in record['candidate']:
                    if candidate['text'] == '':
                        continue
                    data.append({
                        'query': record['query'],
                        'candidate': candidate['text'],
                        'label': candidate['label'],
                    })
            else:
                for candidate in record['candidate']:
                    data.append({
                        'text_id': record['text_id'],
                        'query': record['query'],
                        'candidate': candidate['text'],
                    })
    return data
|
||
|
||
def generate_data(train_data, random_state=42, n_splits=3,
                  output_dir='../dataset/ccks2021'):
    """Write stratified train/dev folds of ``train_data`` to JSON files.

    The examples are split with ``StratifiedKFold`` (shuffled, stratified on
    each example's ``'label'``). For every fold two files are written through
    ``json_to_text``:
    ``<output_dir>/ccks2021_train_seed<random_state>_fold<fold>.json`` and
    ``<output_dir>/ccks2021_dev_seed<random_state>_fold<fold>.json``.

    Args:
        train_data: Sequence of dicts, each with a ``'label'`` key.
        random_state: Seed for the shuffle; also embedded in the file names.
        n_splits: Number of folds to generate (defaults to the original 3).
        output_dir: Directory prefix of the output files (defaults to the
            original hard-coded location).
    """
    # Only positions are split; the examples are looked up afterwards.
    positions = range(len(train_data))
    labels = [example['label'] for example in train_data]
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for fold, (train_index, dev_index) in enumerate(skf.split(positions, labels)):
        fold_train = [train_data[index] for index in train_index]
        fold_dev = [train_data[index] for index in dev_index]
        json_to_text(f'{output_dir}/ccks2021_train_seed{random_state}_fold{fold}.json', fold_train)
        json_to_text(f'{output_dir}/ccks2021_dev_seed{random_state}_fold{fold}.json', fold_dev)
|
||
|
||
if __name__ == '__main__':
    # Pool the examples of both training files into a fresh list so that no
    # intermediate list is mutated through an alias.
    train_paths = (
        '../dataset/ccks2021/round1_train.txt',
        '../dataset/ccks2021/round2_train.txt',
    )
    train_data = []
    for train_path in train_paths:
        train_data.extend(get_data(train_path, 'train'))
    # One set of folds per seed; the seed ends up in the output file names.
    for seed in (42, 24, 33):
        generate_data(train_data, seed)
    print('...............kf finish...........')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from torchblocks.utils import json_to_text | ||
from torchblocks.tasks import get_spans_from_bio_tags | ||
from torchblocks.data.splits import split_ner_stratified_kfold | ||
|
||
''' | ||
采用多标签方式进行划分数据 | ||
''' | ||
|
||
# Input corpora (one "token label" pair per line) and the number of folds.
train_file = '../dataset/cner/train.char.bmes'
dev_file = '../dataset/cner/dev.char.bmes'
folds = 5
sentences = []
lines = []
for input_file in (train_file, dev_file):
    with open(input_file, 'r') as f:
        words, labels = [], []
        for line in f:
            # A document marker or an empty line closes the current sentence.
            if line.startswith("-DOCSTART-") or line in ("", "\n"):
                if words:
                    lines.append([words, labels])
                    words, labels = [], []
                continue
            splits = line.split(" ")
            words.append(splits[0])
            if len(splits) <= 1:
                # No label column on this line: treat the token as outside.
                labels.append("O")
                continue
            label = splits[-1].replace("\n", "")
            # Collapse the BMES scheme to BIO: M-/E- become I-.
            for inner_prefix in ('M-', 'E-'):
                if inner_prefix in label:
                    label = label.replace(inner_prefix, 'I-')
                    break
            else:
                # Drop S- tags; mainly to simplify later experiments.
                if 'S-' in label:
                    label = "O"
            labels.append(label)
        # Flush the last sentence when the file does not end with a blank line.
        if words:
            lines.append([words, labels])
|
||
# Turn every (tokens, BIO labels) pair into a sentence record with entity spans.
for i, (words, labels) in enumerate(lines):
    # Each span unpacks as (tag, start, end); end appears to be inclusive,
    # hence the + 1 for both the stored end offset and the slice.
    entity_spans = [
        [tag, start, end + 1, "".join(words[start:(end + 1)])]
        for tag, start, end in get_spans_from_bio_tags(labels, id2label=None)
    ]
    sentences.append({
        'id': i,
        'text': words,
        'entities': entity_spans,
        'bio_seq': labels,
    })
|
||
# Split on the entity annotations and write one train/dev file pair per fold.
entities_list = [x['entities'] for x in sentences]
# Use the configured fold count instead of repeating the literal 5.
all_indices = split_ner_stratified_kfold(entities_list, num_folds=folds)
for fold, (train_indices, val_indices) in enumerate(all_indices):
    print("The number of train examples: ", len(train_indices))
    print("The number of dev examples: ", len(val_indices))
    train_data = [sentences[i] for i in train_indices]
    dev_data = [sentences[i] for i in val_indices]
    json_to_text(f'../dataset/cner/cner_train_fold{fold}.json', train_data)
    json_to_text(f'../dataset/cner/cner_dev_fold{fold}.json', dev_data)
|
Oops, something went wrong.