diff --git a/jetstream/core/tools/requester.py b/jetstream/core/tools/requester.py index fb464c38..d99d0ee9 100644 --- a/jetstream/core/tools/requester.py +++ b/jetstream/core/tools/requester.py @@ -29,7 +29,7 @@ _SESSION_CACHE = flags.DEFINE_string( 'session_cache', '', 'Location of any pre-cached results' ) -_TEXT = flags.DEFINE_string('text', 'AB', 'The message') +_TEXT = flags.DEFINE_string('text', 'Today is a good day', 'The message') _PRIORITY = flags.DEFINE_integer('priority', 0, 'Message priority') _MAX_TOKENS = flags.DEFINE_integer('max_tokens', 3, 'Maximum number of tokens') @@ -41,22 +41,23 @@ def _GetResponseAsync( """Gets an async response.""" response = stub.Decode(request) + output = "" for token_list in response: - print(token_list.response[0], end='', flush=True) - print('\n') + output += token_list.response[0] + print(f'Prompt: {_TEXT.value}') + print(f'Response: {output}') def main(argv: Sequence[str]) -> None: del argv # Note: Uses insecure_channel only for local testing. Please add grpc credentials for Production. address = f'{_SERVER.value}:{_PORT.value}' - print(address) with grpc.insecure_channel( address ) as channel: grpc.channel_ready_future(channel).result() stub = jetstream_pb2_grpc.OrchestratorStub(channel) - print('Making request') + print(f'Sending request to: {address}') request = jetstream_pb2.DecodeRequest( session_cache=_SESSION_CACHE.value, additional_text=_TEXT.value,