diff --git a/gnes/cli/api.py b/gnes/cli/api.py index 8990008c..6528c9aa 100644 --- a/gnes/cli/api.py +++ b/gnes/cli/api.py @@ -58,8 +58,6 @@ def client(args): return _client_http(args) elif args.client == 'cli': return _client_cli(args) - elif args.client == 'benchmark': - return _client_bm(args) else: raise ValueError('gnes client must follow with a client type from {http, cli, benchmark...}\n' 'see "gnes client --help" for details') @@ -94,11 +92,6 @@ def _client_cli(args): CLIClient(args) -def _client_bm(args): - from ..client.benchmark import BenchmarkClient - BenchmarkClient(args) - - def compose(args): from ..composer.base import YamlComposer from ..composer.flask import YamlComposerFlask diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py index ff8923fa..764c4ee6 100644 --- a/gnes/cli/parser.py +++ b/gnes/cli/parser.py @@ -365,21 +365,6 @@ def set_client_cli_parser(parser=None): return parser -def set_client_benchmark_parser(parser=None): - if not parser: - parser = set_base_parser() - _set_grpc_parser(parser) - parser.add_argument('--batch_size', type=int, default=64, - help='the size of the request to split') - parser.add_argument('--request_length', type=int, - default=1024, - help='binary string length of each request') - parser.add_argument('--num_requests', type=int, - default=128, - help='number of total requests') - return parser - - def set_client_http_parser(parser=None): if not parser: parser = set_base_parser() @@ -422,8 +407,6 @@ def get_main_parser(): set_client_http_parser( spp.add_parser('http', help='start a client that allows HTTP requests as input', formatter_class=adf)) set_client_cli_parser(spp.add_parser('cli', help='start a client that allows stdin as input', formatter_class=adf)) - set_client_benchmark_parser( - spp.add_parser('benchmark', help='start a client for benchmark and unittest', formatter_class=adf)) # others set_composer_flask_parser( diff --git a/gnes/client/benchmark.py b/gnes/client/benchmark.py deleted file mode 100644 
index 09265c66..00000000 --- a/gnes/client/benchmark.py +++ /dev/null @@ -1,52 +0,0 @@ -# Tencent is pleased to support the open source community by making GNES available. -# -# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import grpc - -from ..helper import TimeContext -from ..proto import gnes_pb2_grpc, RequestGenerator - - -class BenchmarkClient: - def __init__(self, args): - - all_bytes = [b'a' * args.request_length] * args.num_requests - - with grpc.insecure_channel( - '%s:%d' % (args.grpc_host, args.grpc_port), - options=[('grpc.max_send_message_length', args.max_message_size), - ('grpc.max_receive_message_length', args.max_message_size)]) as channel: - stub = gnes_pb2_grpc.GnesRPCStub(channel) - - id = 0 - with TimeContext('StreamCall') as tc: - resp = stub.StreamCall(RequestGenerator.index(all_bytes, args.batch_size)) - for r in resp: - assert r.request_id == id - id += 1 - stream_call_el = tc.duration - - with TimeContext('Call') as tc: - for req in RequestGenerator.index(all_bytes, batch_size=1): - r = stub.Call(req) - assert r.request_id == req.request_id - - call_el = tc.duration - - print('num_requests %d\n' - 'request_length %d' % (args.num_requests, args.request_length)) - print('StreamCall %3.3f s\n' - 'Call %3.3f s\n' % (stream_call_el, call_el)) diff --git a/gnes/encoder/base.py b/gnes/encoder/base.py index cff40538..cbd0edfa 100644 --- a/gnes/encoder/base.py +++ 
b/gnes/encoder/base.py @@ -49,6 +49,7 @@ def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]: class BaseNumericEncoder(BaseEncoder): + """Note that a NumericEncoder cannot be used as the first encoder of the pipeline""" def encode(self, data: np.ndarray, *args, **kwargs) -> np.ndarray: pass diff --git a/gnes/encoder/text/flair.py b/gnes/encoder/text/flair.py index 8355017b..52fece88 100644 --- a/gnes/encoder/text/flair.py +++ b/gnes/encoder/text/flair.py @@ -14,7 +14,7 @@ # limitations under the License. -from typing import List +from typing import List, Tuple import numpy as np @@ -25,16 +25,22 @@ class FlairEncoder(BaseTextEncoder): is_trained = True - def __init__(self, pooling_strategy: str = 'mean', *args, **kwargs): + def __init__(self, + word_embedding: str = 'glove', + flair_embeddings: Tuple[str] = ('news-forward', 'news-backward'), + pooling_strategy: str = 'mean', *args, **kwargs): super().__init__(*args, **kwargs) + + self.word_embedding = word_embedding + self.flair_embeddings = flair_embeddings self.pooling_strategy = pooling_strategy def post_init(self): from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings, FlairEmbeddings self._flair = DocumentPoolEmbeddings( - [WordEmbeddings('glove'), - FlairEmbeddings('news-forward'), - FlairEmbeddings('news-backward')], + [WordEmbeddings(self.word_embedding), + FlairEmbeddings(self.flair_embeddings[0]), + FlairEmbeddings(self.flair_embeddings[1])], pooling=self.pooling_strategy) @batching diff --git a/tests/test_flair_encoder.py b/tests/test_flair_encoder.py index 6a4aef86..798c2eb6 100644 --- a/tests/test_flair_encoder.py +++ b/tests/test_flair_encoder.py @@ -17,15 +17,14 @@ def setUp(self): if line: self.test_str.append(line) - self.flair_encoder = FlairEncoder( - model_name=os.environ.get('FLAIR_CI_MODEL'), - pooling_strategy="REDUCE_MEAN") + self.flair_encoder = FlairEncoder(model_name=os.environ.get('FLAIR_CI_MODEL')) @unittest.SkipTest def test_encoding(self): - 
vec = self.flair_encoder.encode(self.test_str) - self.assertEqual(vec.shape[0], len(self.test_str)) - self.assertEqual(vec.shape[1], 512) + vec = self.flair_encoder.encode(self.test_str[:2]) + print(vec.shape) + self.assertEqual(vec.shape[0], 2) + self.assertEqual(vec.shape[1], 4196) @unittest.SkipTest def test_dump_load(self): diff --git a/tests/test_stream_grpc.py b/tests/test_stream_grpc.py index 38fed2ed..5ce078e6 100644 --- a/tests/test_stream_grpc.py +++ b/tests/test_stream_grpc.py @@ -4,8 +4,7 @@ import grpc -from gnes.cli.parser import set_frontend_parser, set_router_parser, set_client_benchmark_parser -from gnes.client.benchmark import BenchmarkClient +from gnes.cli.parser import set_frontend_parser, set_router_parser from gnes.helper import TimeContext from gnes.proto import RequestGenerator, gnes_pb2_grpc from gnes.service.base import SocketType, MessageHandler, BaseService as BS @@ -55,13 +54,6 @@ def test_bm_frontend(self): '--yaml_path', 'BaseRouter' ]) - b_args = set_client_benchmark_parser().parse_args([ - '--num_requests', '10', - '--request_length', '65536' - ]) - with RouterService(p_args), FrontendService(args): - BenchmarkClient(b_args) - def test_grpc_frontend(self): args = set_frontend_parser().parse_args([ '--grpc_host', '127.0.0.1',