Skip to content

Commit

Permalink
Merge pull request #4 from BrikerMan/develop
Browse files Browse the repository at this point in the history
release v0.1.6
  • Loading branch information
BrikerMan authored Feb 4, 2019
2 parents ab88572 + c0ab081 commit 5d865af
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 21 deletions.
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ exclude_lines =
# Don't complain about missing debug-only code:
def __repr__
if self\.debug
if debug_info:

# Don't complain if tests don't hit defensive assertion code:
raise AssertionError
Expand Down
3 changes: 3 additions & 0 deletions kashgari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import kashgari.corpus
import kashgari.tasks

from kashgari.tasks import classification
from kashgari.tasks import seq_labeling


if __name__ == "__main__":
print("Hello world")
5 changes: 3 additions & 2 deletions kashgari/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ class ChinaPeoplesDailyNerCorpus(object):
__corpus_name__ = 'corpus/china-people-daily-ner-corpus'
__zip_file__name = 'corpus/china-people-daily-ner-corpus.tar.gz'

__desc__ = """ Download from NLPCC 2018 Task4 dataset
"""
__desc__ = """
https://github.com/zjy-ucas/ChineseNER/
"""

@classmethod
def get_sequence_tagging_data(cls,
Expand Down
62 changes: 55 additions & 7 deletions kashgari/tasks/classification/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_data_generator(self,
batch_size: int = 64,
is_bert: bool = False):
while True:
page_list = list(range(len(x_data) // batch_size + 1))
page_list = list(range((len(x_data) // batch_size) + 1))
random.shuffle(page_list)
for page in page_list:
start_index = page * batch_size
Expand Down Expand Up @@ -183,7 +183,35 @@ def fit(self,
class_weight=class_weights,
**fit_kwargs)

def predict(self, sentence: Union[List[str], List[List[str]]], batch_size=None):
def _format_output_dic(self, words: List[str], res: np.ndarray):
    """
    Build the dict-style prediction output for a single sentence.

    :param words: the tokens of the input sentence, echoed back in the result
    :param res: per-label scores for this sentence, indexed by label idx
    :return: dict with keys ``words``, ``class`` (best candidate, or None when
             ``res`` is empty) and ``class_candidates`` (candidates sorted by
             descending confidence)
    """
    # Rank label indices by descending score.
    ranked = sorted(enumerate(res), key=lambda pair: -pair[1])

    # Only labels scoring above 0.01 are reported as candidates.
    candidates = [
        {
            'name': self.convert_idx_to_label([idx])[0],
            'confidence': float(score),
        }
        for idx, score in ranked
        if float(score) > 0.01
    ]

    # When every score is at or below the threshold (e.g. a near-uniform
    # distribution over many labels) the list above is empty; fall back to
    # the top-ranked label so that ``class`` is always available instead of
    # raising IndexError on ``candidates[0]``.
    if not candidates and ranked:
        top_idx, top_score = ranked[0]
        candidates.append({
            'name': self.convert_idx_to_label([top_idx])[0],
            'confidence': float(top_score),
        })

    return {
        'words': words,
        'class': candidates[0] if candidates else None,
        'class_candidates': candidates,
    }

def predict(self,
sentence: Union[List[str], List[List[str]]],
batch_size=None,
output_dict=False,
debug_info=False) -> Union[List[str], str, List[Dict], Dict]:
"""
predict with model
:param sentence: single sentence as List[str] or list of sentence as List[List[str]]
:param batch_size: predict batch_size
:param output_dict: if True, return a dict (or list of dicts) with the predicted class and candidate confidences
:param debug_info: log the model input and output using logging.info when True
:return:
"""
tokens = self.embedding.tokenize(sentence)
is_list = not isinstance(sentence[0], str)
if is_list:
Expand All @@ -198,12 +226,32 @@ def predict(self, sentence: Union[List[str], List[List[str]]], batch_size=None):
x = [padded_tokens, np.zeros(shape=(len(padded_tokens), self.embedding.sequence_length))]
else:
x = padded_tokens
predict_result = self.model.predict(x, batch_size=batch_size).argmax(-1)
labels = self.convert_idx_to_label(predict_result)
if is_list:
return labels
res = self.model.predict(x, batch_size=batch_size)
predict_result = res.argmax(-1)

if debug_info:
logging.info('input: {}'.format(x))
logging.info('output: {}'.format(res))
logging.info('output argmax: {}'.format(predict_result))

if output_dict:
if is_list:
words_list: List[List[str]] = sentence
else:
words_list: List[List[str]] = [sentence]
results = []
for index in range(len(words_list)):
results.append(self._format_output_dic(words_list[index], res[index]))
if is_list:
return results
else:
return results[0]
else:
return labels[0]
results = self.convert_idx_to_label(predict_result)
if is_list:
return results
else:
return results[0]

def evaluate(self, x_data, y_data, batch_size=None, digits=4, debug_info=False) -> Tuple[float, float, Dict]:
y_pred = self.predict(x_data, batch_size=batch_size)
Expand Down
2 changes: 2 additions & 0 deletions kashgari/tasks/seq_labeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .blstm_model import BLSTMModel
from .blstm_crf_model import BLSTMCRFModel
from .cnn_lstm_model import CNNLSTMModel
from .base_model import SequenceLabelingModel


if __name__ == '__main__':
print("hello, world")
63 changes: 52 additions & 11 deletions kashgari/tasks/seq_labeling/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras.preprocessing import sequence
from keras.utils import to_categorical
from seqeval.metrics import classification_report
from seqeval.metrics.sequence_labeling import get_entities

import kashgari.macros as k
from kashgari.utils import helper
Expand Down Expand Up @@ -126,7 +127,7 @@ def get_data_generator(self,
batch_size: int = 64):
is_bert = self.embedding.embedding_type == 'bert'
while True:
page_list = list(range(len(x_data) // batch_size + 1))
page_list = list(range((len(x_data) // batch_size) + 1))
random.shuffle(page_list)
for page in page_list:
start_index = page * batch_size
Expand Down Expand Up @@ -234,16 +235,38 @@ def fit(self,
epochs=epochs,
**fit_kwargs)

def _format_output_dic(self, words: List[str], tags: List[str], chunk_joiner: str):
    """
    Build the dict-style prediction output for a single sentence.

    :param words: the tokens of the input sentence, echoed back in the result
    :param tags: the predicted tag for each token, passed to ``get_entities``
    :param chunk_joiner: string placed between the words of one entity
    :return: dict with keys ``words`` and ``entities``; each entity carries
             ``text``, ``type``, ``beginOffset`` and ``endOffset``
    """
    entities = []
    for label, first, last in get_entities(tags):
        # The reported end index is bumped by one before slicing, so
        # ``endOffset`` in the output is an exclusive bound.
        stop = last + 1
        entities.append({
            'text': chunk_joiner.join(words[first:stop]),
            'type': label,
            'beginOffset': first,
            'endOffset': stop,
        })
    return {
        'words': words,
        'entities': entities,
    }

def predict(self,
sentence: Union[List[str], List[List[str]]],
batch_size=None,
convert_to_labels=True):
output_dict=False,
chunk_joiner=' ',
debug_info=False):
"""
predict with model
:param sentence: input for predict, accept a single sentence as type List[str] or
list of sentence as List[List[str]]
:param batch_size: predict batch_size
:param convert_to_labels: if True, return labels or return label idxs
:param output_dict: if True, return a dict (or list of dicts) with the words and the extracted entities
:param chunk_joiner: the string used to join the words of each entity chunk when output_dict is True
:param debug_info: log the model input and output using logging.info when True
:return:
"""
tokens = self.embedding.tokenize(sentence)
Expand All @@ -262,16 +285,34 @@ def predict(self,
x = [padded_tokens, np.zeros(shape=(len(padded_tokens), self.embedding.sequence_length))]
else:
x = padded_tokens
predict_result = self.model.predict(x, batch_size=batch_size).argmax(-1)
if convert_to_labels:
result = self.convert_idx_to_labels(predict_result, seq_length)
else:
result = predict_result

if is_list:
return result
predict_result_prob = self.model.predict(x, batch_size=batch_size)
predict_result = predict_result_prob.argmax(-1)
if debug_info:
logging.info('input: {}'.format(x))
logging.info('output: {}'.format(predict_result_prob))
logging.info('output argmax: {}'.format(predict_result))

result: List[List[str]] = self.convert_idx_to_labels(predict_result, seq_length)
if output_dict:
dict_list = []
if is_list:
sentence_list: List[List[str]] = sentence
else:
sentence_list: List[List[str]] = [sentence]
for index in range(len(sentence_list)):
dict_list.append(self._format_output_dic(sentence_list[index],
result[index],
chunk_joiner))
if is_list:
return dict_list
else:
return dict_list[0]
else:
return result[0]
if is_list:
return result
else:
return result[0]

def evaluate(self, x_data, y_data, batch_size=None, digits=4, debug_info=False) -> Tuple[float, float, Dict]:
seq_length = [len(x) for x in x_data]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

setup(
name=NAME,
version='0.1.5',
version='0.1.6',
description=DESCRIPTION,
long_description=README,
long_description_content_type="text/markdown",
Expand Down
1 change: 1 addition & 0 deletions tests/test_classifier_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_predict(self):
assert isinstance(self.model.predict(sentence), str)
assert isinstance(self.model.predict([sentence]), list)
logging.info('test predict: {} -> {}'.format(sentence, self.model.predict(sentence)))
self.model.predict(sentence, output_dict=True)

def test_eval(self):
self.test_fit()
Expand Down
1 change: 1 addition & 0 deletions tests/test_seq_labeling_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test_predict(self):
self.assertTrue(isinstance(self.model.predict(sentence)[0], str))
self.assertTrue(isinstance(self.model.predict([sentence])[0], list))
self.assertEqual(len(self.model.predict(sentence)), len(sentence))
self.model.predict(sentence, output_dict=True)

def test_eval(self):
self.test_fit()
Expand Down

0 comments on commit 5d865af

Please sign in to comment.