-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathreference_data_process.py
161 lines (136 loc) · 5.36 KB
/
reference_data_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import sys
import re
import math
from abc import ABC
import pandas as pd
from sentence_transformers import SentenceTransformer
from milvus import Milvus
from es import Es
# from utils import get_embedding
class ReferenceDataProcess(ABC):
    """Prepare crawled reference data and load it into Milvus / Elasticsearch.

    Reads a crawled ``.csv`` data file containing (at least) the columns
    ``['title', 'text', 'date', 'url']``, normalizes each row, splits the
    Chinese text into sentences and packs them into chunks, and builds
    document lists suitable for a Milvus vector store (``process_for_milvus``)
    and an ES index (``process_for_es``).

    Args:
        path: Path to the crawled data file; must be a ``.csv`` with the
            required columns (extra columns are ignored).
        embedding_model: SentenceTransformer used to embed text chunks.
        source: Data-source label; defaults to the csv file's base name.

    Raises:
        ValueError: if ``path`` is not a ``.csv`` file.
        KeyError: if any required column is missing.
    """

    # Columns every input file must provide; extra columns are tolerated
    # and simply dropped.
    REQUIRED_COLUMNS = ('title', 'text', 'date', 'url')

    def __init__(
        self,
        path: str,
        embedding_model: "SentenceTransformer",
        source: str = None
    ):
        name, ext = os.path.splitext(os.path.split(path)[-1])
        if ext != '.csv':
            raise ValueError('path must be a .csv file.')
        data = pd.read_csv(path)
        # Require the four needed columns but tolerate extras: the original
        # exact-set comparison rejected files with additional columns even
        # though only these four are ever used below.
        required = list(self.REQUIRED_COLUMNS)
        if set(required) - set(data.columns):
            raise KeyError("Missing field, field must be ['title', 'text', 'date', 'url'].")
        self.data = data[required]
        self.source = source if source else name
        self.embedding_model = embedding_model

    def _is_contain_chinese(self, string):
        """Return True if *string* contains at least one Chinese character."""
        return re.search('[\u4e00-\u9fa5]', string) is not None

    def _text_split(self, text, size=80):
        """Split Chinese text into sentences.

        Sentences longer than *size* characters are cut into fixed-size
        slices; segments containing no Chinese characters are discarded.
        """
        patterns = [
            # sentence-ending punctuation not followed by a closing quote
            r"""([。!?\?\!])([^”'"])""",
            # ASCII ellipsis "......"
            r"""(\.{6})([^”'"])""",
            # unicode ellipsis "……"
            r"""(\…{2})([^”'"])""",
            # ender/ellipsis immediately followed by a closing quote.
            # Fixed: the original character class contained the literal
            # characters '{', '6', '}', '2' (quantifiers are inert inside
            # [...]), which wrongly split after e.g. `6”`.
            r"""([\.…。!?\?\!][”'"])([^,。!?\?\!])""",
        ]
        for pattern in patterns:
            # insert a newline between the ender (group 1) and the next char
            text = re.sub(pattern, r"\1\n\2", text)
        sentences = []
        for segment in text.split('\n'):
            if segment and self._is_contain_chinese(segment):
                # break overly long sentences into *size*-character slices
                for start in range(0, len(segment), size):
                    sentences.append(segment[start:start + size])
        return sentences

    def _gene_chunk(self, sentences, chunk_size=250):
        """Greedily pack sentences into chunks of roughly *chunk_size* chars.

        A chunk is emitted as soon as it reaches 95% of *chunk_size*, so
        chunks may slightly overshoot but never badly undershoot.
        """
        chunks = []
        current = []
        current_len = 0
        for sentence in sentences:
            current.append(sentence)
            current_len += len(sentence)
            if current_len >= chunk_size * 0.95:
                chunks.append(''.join(current))
                current = []
                current_len = 0
        if current:  # flush the trailing partial chunk
            chunks.append(''.join(current))
        return chunks

    def _clean_rows(self):
        """Drop rows with missing values and yield normalized field tuples.

        Mutates ``self.data`` in place (dropna + index reset), then yields
        ``(title, text, date, url)`` with whitespace stripped and internal
        whitespace runs in *text* collapsed to single spaces.
        """
        self.data.dropna(axis=0, inplace=True)
        self.data.reset_index(drop=True, inplace=True)
        for _, row in self.data.iterrows():
            title, text, date, url = row
            title = title.strip()
            text = re.sub(r'\s+', ' ', text.strip())
            # date may be parsed as a non-string (e.g. NaT/number); only
            # strip when it is actually a string
            date = date.strip() if isinstance(date, str) else date
            yield title, text, date, url.strip()

    def process_for_milvus(
        self,
        chunk_size: int = 250,
        batch_size: int = 32,
        show_progress_bar: bool = None,
        device: str = None
    ):
        """Clean, chunk and embed the data into Milvus-ready documents.

        Stores the result in ``self.docs_for_milvus`` and returns the
        number of documents produced.
        """
        docs = []
        all_chunks = []
        for title, text, date, url in self._clean_rows():
            chunks = self._gene_chunk(self._text_split(text), chunk_size=chunk_size)
            for chunk in chunks:
                docs.append({
                    "title": title,
                    "text": chunk,
                    "url": url,
                    "source": self.source,
                    "date": date,
                })
            all_chunks.extend(chunks)
        if all_chunks:  # avoid invoking the encoder on an empty batch
            embeddings = self.embedding_model.encode(
                sentences=all_chunks,
                batch_size=batch_size,
                show_progress_bar=show_progress_bar,
                device=device
            )
            # docs and all_chunks were built in lockstep, so they align 1:1
            for doc, vector in zip(docs, embeddings):
                doc['text_vector'] = vector
        self.docs_for_milvus = docs
        return len(docs)

    def process_for_es(self):
        """Clean the data into ES-ready documents.

        Title and body are fused into a single tagged ``text`` field.
        Stores the result in ``self.docs_for_es`` and returns the number
        of documents produced.
        """
        docs = []
        for title, text, date, url in self._clean_rows():
            docs.append({
                "text": "[title]" + title + "[title]" + "[text]" + text + "[text]",
                "url": url,
                "source": self.source,
                "date": date,
            })
        self.docs_for_es = docs
        return len(docs)

    def add_to_milvus(self, client: "Milvus"):
        """Insert the prepared documents into Milvus (run process_for_milvus first)."""
        return client.add_data(data=self.docs_for_milvus)

    def add_to_es(self, es: "Es"):
        """Insert the prepared documents into the ES index (run process_for_es first)."""
        return es.add_many_data(documents=self.docs_for_es)