diff --git a/kashgari/embeddings/embeddings.py b/kashgari/embeddings/embeddings.py index 3e836855..6dc943f2 100644 --- a/kashgari/embeddings/embeddings.py +++ b/kashgari/embeddings/embeddings.py @@ -46,7 +46,7 @@ class BaseEmbedding(object): def __init__(self, name_or_path: str, - sequence_length: int, + sequence_length: int = None, embedding_size: int = None, **kwargs): """ @@ -322,6 +322,20 @@ def prepare_model_input(self, input_x: np.array, **kwargs) -> np.array: class CustomEmbedding(BaseEmbedding): + def __init__(self, + name_or_path: str = 'custom-embedding', + sequence_length: int = None, + embedding_size: int = None, + **kwargs): + """ + :param name_or_path: just a name for custom embedding + :param sequence_length: length of max sequence, all embedding is shaped as (sequence_length, embedding_size) + :param embedding_size: embedding vector size, only need to set when using a CustomEmbedding + :param kwargs: kwargs to pass to the method, func: `BaseEmbedding.build` + """ + if sequence_length is None or embedding_size is None: + raise ValueError('Must set sequence_length and embedding_size when using the CustomEmbedding layer') + super(CustomEmbedding, self).__init__(name_or_path, sequence_length, embedding_size, **kwargs) def build(self, **kwargs): if self._token2idx is None: diff --git a/kashgari/tasks/classification/base_model.py b/kashgari/tasks/classification/base_model.py index 76727a02..215d49c3 100644 --- a/kashgari/tasks/classification/base_model.py +++ b/kashgari/tasks/classification/base_model.py @@ -187,11 +187,11 @@ def _format_output_dic(self, words: List[str], res: np.ndarray): results = sorted(list(enumerate(res)), key=lambda x: -x[1]) candidates = [] for result in results: - if float(result[1]) > 0.01: - candidates.append({ - 'name': self.convert_idx_to_label([result[0]])[0], - 'confidence': float(result[1]), - }) + candidates.append({ + 'name': self.convert_idx_to_label([result[0]])[0], + 'confidence': float(result[1]), + }) + data = { 
'words': words, 'class': candidates[0], diff --git a/setup.py b/setup.py index e5124e9e..53723ad0 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ """ import pathlib - +from version import __version__ from setuptools import find_packages, setup # Package meta-data. @@ -28,7 +28,7 @@ required = [ 'Keras>=2.2.0', - 'keras_bert', + 'keras-bert==0.25.0', 'h5py>=2.7.1', 'keras-bert==0.25.0', 'scikit-learn>=0.19.1', @@ -46,7 +46,7 @@ setup( name=NAME, - version='0.1.6', + version=__version__, description=DESCRIPTION, long_description=README, long_description_content_type="text/markdown", diff --git a/version.py b/version.py new file mode 100644 index 00000000..a9ff152f --- /dev/null +++ b/version.py @@ -0,0 +1,14 @@ +# encoding: utf-8 +""" +@author: BrikerMan +@contact: eliyar917@gmail.com +@blog: https://eliyar.biz + +@version: 1.0 +@license: Apache Licence +@file: version.py +@time: 2019-02-21 15:22 + +""" + +__version__ = '0.1.7'