Skip to content

Commit

Permalink
Merge pull request #18 from BrikerMan/develop
Browse files Browse the repository at this point in the history
fix several minimal details
  • Loading branch information
BrikerMan authored Feb 22, 2019
2 parents a94b7ab + 762f4ce commit 03595a4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 9 deletions.
16 changes: 15 additions & 1 deletion kashgari/embeddings/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class BaseEmbedding(object):

def __init__(self,
name_or_path: str,
sequence_length: int,
sequence_length: int = None,
embedding_size: int = None,
**kwargs):
"""
Expand Down Expand Up @@ -322,6 +322,20 @@ def prepare_model_input(self, input_x: np.array, **kwargs) -> np.array:


class CustomEmbedding(BaseEmbedding):
def __init__(self,
name_or_path: str = 'custom-embedding',
sequence_length: int = None,
embedding_size: int = None,
**kwargs):
"""
:param name_or_path: just a name for custom embedding
:param sequence_length: length of max sequence, all embedding is shaped as (sequence_length, embedding_size)
:param embedding_size: embedding vector size, only need to set when using a CustomEmbedding
:param kwargs: kwargs to pass to the method, func: `BaseEmbedding.build`
"""
if sequence_length is None or embedding_size is None:
raise ValueError('Must set sequence_length and sequence_length when using the CustomEmbedding layer')
super(CustomEmbedding, self).__init__(name_or_path, sequence_length, embedding_size, **kwargs)

def build(self, **kwargs):
if self._token2idx is None:
Expand Down
10 changes: 5 additions & 5 deletions kashgari/tasks/classification/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ def _format_output_dic(self, words: List[str], res: np.ndarray):
results = sorted(list(enumerate(res)), key=lambda x: -x[1])
candidates = []
for result in results:
if float(result[1]) > 0.01:
candidates.append({
'name': self.convert_idx_to_label([result[0]])[0],
'confidence': float(result[1]),
})
candidates.append({
'name': self.convert_idx_to_label([result[0]])[0],
'confidence': float(result[1]),
})

data = {
'words': words,
'class': candidates[0],
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""
import pathlib

from version import __version__
from setuptools import find_packages, setup

# Package meta-data.
Expand All @@ -28,7 +28,7 @@

required = [
'Keras>=2.2.0',
'keras_bert',
'keras-bert==0.25.0',
'h5py>=2.7.1',
'keras-bert==0.25.0',
'scikit-learn>=0.19.1',
Expand All @@ -46,7 +46,7 @@

setup(
name=NAME,
version='0.1.6',
version=__version__,
description=DESCRIPTION,
long_description=README,
long_description_content_type="text/markdown",
Expand Down
14 changes: 14 additions & 0 deletions version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# encoding: utf-8
"""
@author: BrikerMan
@contact: eliyar917@gmail.com
@blog: https://eliyar.biz
@version: 1.0
@license: Apache Licence
@file: __version__.py
@time: 2019-02-21 15:22
"""

__version__ = '0.1.7'

0 comments on commit 03595a4

Please sign in to comment.