diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile b/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile index 9e71caead..22aa43711 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile @@ -20,4 +20,5 @@ tox = "*" black = "*" ruff = "*" pytest = "*" +pytest-asyncio = "*" grpcio-tools = "*" diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile.lock b/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile.lock index 3bc6cd524..902759e9e 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile.lock +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "96638d9c07c072bb3b3fbb566995ca3355ecf8229f11903c9cb0295684dd267e" + "sha256": "4264f3d81c7bb26bf96d119948d678cd2eac6306a4389679aae2a4be255abf10" }, "pipfile-spec": 6, "requires": {}, @@ -402,6 +402,14 @@ "index": "pypi", "version": "==7.4.3" }, + "pytest-asyncio": { + "hashes": [ + "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d", + "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b" + ], + "index": "pypi", + "version": "==0.21.1" + }, "ruff": { "hashes": [ "sha256:171276c1df6c07fa0597fb946139ced1c2978f4f0b8254f201281729981f3c17", @@ -2013,7 +2021,6 @@ "sha256:e529578d017045e2f0ed12d2e00e7e99f780f477234da4aae799ec4afca89f37", "sha256:edd2ffbb789712d83fee19ab009949f998a35c51ad9f9beb39109357416344ff" ], - "markers": "python_version >= '3.8'", "version": "==0.5.1" }, "tqdm": { diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/__main__.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/__main__.py index e39d6dddf..a6988f314 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/__main__.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/__main__.py @@ -14,11 +14,21 @@ # limitations under the License. # +import asyncio import logging import sys from langstream_grpc.grpc_service import AgentServer + +async def main(target, config, context): + server = AgentServer(target) + await server.init(config, context) + await server.start() + await server.grpc_server.wait_for_termination() + await server.stop() + + if __name__ == "__main__": logging.addLevelName(logging.WARNING, "WARN") logging.basicConfig( @@ -34,7 +44,4 @@ ) sys.exit(1) - server = AgentServer(sys.argv[1], sys.argv[2], sys.argv[3]) - server.start() - server.grpc_server.wait_for_termination() - server.stop() + asyncio.run(main(*sys.argv[1:])) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index 9fca983a1..9f8fd7a20 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -14,7 +14,7 @@ # limitations under the License. # -import concurrent +import asyncio import importlib import json import os @@ -23,7 +23,7 @@ import threading from concurrent.futures import Future from io import BytesIO -from typing import Iterable, Union, List, Tuple, Any, Optional, Dict +from typing import Union, List, Tuple, Any, Optional, Dict, AsyncIterable import fastavro import grpc @@ -78,30 +78,28 @@ def __init__(self, agent: Union[Agent, Source, Sink, Processor]): self.schemas = {} self.client_schemas = {} - def agent_info(self, _, __): - info = call_method_if_exists(self.agent, "agent_info") or {} + async def agent_info(self, _, __): + info = await acall_method_if_exists(self.agent, "agent_info") or {} return InfoResponse(json_info=json.dumps(info)) - def get_topic_producer_records(self, request_iterator, context): - # TODO: to be implementedbla - for _ in request_iterator: - yield None + async def get_topic_producer_records(self, request_iterator, context): + # TODO: to be implemented + async for _ in request_iterator: + yield - def read(self, requests: Iterable[SourceRequest], _): + async def read(self, requests: AsyncIterable[SourceRequest], _): read_records = {} op_result = [] - read_thread = threading.Thread( - target=self.handle_read_requests, - args=(requests, read_records, op_result), - ) last_record_id = 0 - read_thread.start() + read_requests_task = asyncio.create_task( + self.handle_read_requests(requests, read_records, op_result) + ) while True: if len(op_result) > 0: if op_result[0] is True: break raise op_result[0] - records = self.agent.read() + records = await asyncio.to_thread(self.agent.read) if len(records) > 0: records = [wrap_in_record(record) for record in records] grpc_records = [] @@ -115,25 +113,25 @@ def read(self, requests: Iterable[SourceRequest], _): grpc_records[i].record_id = last_record_id read_records[last_record_id] = record yield SourceResponse(records=grpc_records) - read_thread.join() + read_requests_task.cancel() - def handle_read_requests( + async def handle_read_requests( self, - requests: Iterable[SourceRequest], + requests: AsyncIterable[SourceRequest], read_records: Dict[int, Record], read_result, ): try: - for request in requests: + async for request in requests: if len(request.committed_records) > 0: for record_id in request.committed_records: record = read_records.pop(record_id, None) if record is not None: - call_method_if_exists(self.agent, "commit", record) + await acall_method_if_exists(self.agent, "commit", record) if request.HasField("permanent_failure"): failure = request.permanent_failure record = read_records.pop(failure.record_id, None) - call_method_if_exists( + await acall_method_if_exists( self.agent, "permanent_failure", record, @@ -159,83 +157,49 @@ def handle_requests(handler, requests): pass thread.join() - def process(self, requests: Iterable[ProcessorRequest], _): - return self.handle_requests(self.handle_process_requests, requests) - - def process_record( - self, source_record, get_processed_fn, get_processed_args, process_results - ): - grpc_result = ProcessorResult(record_id=source_record.record_id) - try: - processed_records = get_processed_fn(*get_processed_args) - if isinstance(processed_records, Future): - processed_records.add_done_callback( - lambda f: self.process_record( - source_record, f.result, (), process_results - ) - ) - else: - for record in processed_records: - schemas, grpc_record = self.to_grpc_record(wrap_in_record(record)) - for schema in schemas: - process_results.put(ProcessorResponse(schema=schema)) - grpc_result.records.append(grpc_record) - process_results.put(ProcessorResponse(results=[grpc_result])) - except Exception as e: - grpc_result.error = str(e) - process_results.put(ProcessorResponse(results=[grpc_result])) - - def handle_process_requests( - self, requests: Iterable[ProcessorRequest], process_results - ): - for request in requests: + async def process(self, requests: AsyncIterable[ProcessorRequest], _): + async for request in requests: if request.HasField("schema"): schema = fastavro.parse_schema(json.loads(request.schema.value)) self.client_schemas[request.schema.schema_id] = schema if len(request.records) > 0: for source_record in request.records: - self.process_record( - source_record, - lambda r: self.agent.process(self.from_grpc_record(r)), - (source_record,), - process_results, - ) - process_results.put(True) - - def write(self, requests: Iterable[SinkRequest], _): - return self.handle_requests(self.handle_write_requests, requests) - - def write_record( - self, source_record, get_written_fn, get_written_args, write_results - ): - try: - result = get_written_fn(*get_written_args) - if isinstance(result, Future): - result.add_done_callback( - lambda f: self.write_record( - source_record, f.result, (), write_results - ) - ) - else: - write_results.put(SinkResponse(record_id=source_record.record_id)) - except Exception as e: - write_results.put( - SinkResponse(record_id=source_record.record_id, error=str(e)) - ) - - def handle_write_requests(self, requests: Iterable[SinkRequest], write_results): - for request in requests: + grpc_result = ProcessorResult(record_id=source_record.record_id) + try: + processed_records = await asyncio.to_thread( + self.agent.process, self.from_grpc_record(source_record) + ) + if isinstance(processed_records, Future): + processed_records = await asyncio.wrap_future( + processed_records + ) + for record in processed_records: + schemas, grpc_record = self.to_grpc_record( + wrap_in_record(record) + ) + for schema in schemas: + yield ProcessorResponse(schema=schema) + grpc_result.records.append(grpc_record) + yield ProcessorResponse(results=[grpc_result]) + except Exception as e: + grpc_result.error = str(e) + yield ProcessorResponse(results=[grpc_result]) + + async def write(self, requests: AsyncIterable[SinkRequest], context): + async for request in requests: if request.HasField("schema"): schema = fastavro.parse_schema(json.loads(request.schema.value)) self.client_schemas[request.schema.schema_id] = schema if request.HasField("record"): - self.write_record( - request.record, - lambda r: self.agent.write(self.from_grpc_record(r)), - (request.record,), - write_results, - ) - write_results.put(True) + try: + result = await asyncio.to_thread( + self.agent.write, self.from_grpc_record(request.record) + ) + if isinstance(result, Future): + await asyncio.wrap_future(result) + yield SinkResponse(record_id=request.record.record_id) + except Exception as e: + yield SinkResponse(record_id=request.record.record_id, error=str(e)) def from_grpc_record(self, record: GrpcRecord) -> SimpleRecord: return RecordWithId( @@ -333,6 +297,12 @@ def call_method_if_exists(klass, method, *args, **kwargs): return None +async def acall_method_if_exists(klass, method, *args, **kwargs): + return await asyncio.to_thread( + call_method_if_exists, klass, method, *args, **kwargs + ) + + class MainExecutor(threading.Thread): def __init__(self, onError, klass, method, *args, **kwargs): threading.Thread.__init__(self) @@ -364,17 +334,16 @@ def call_method_new_thread_if_exists(klass, methodName, *args, **kwargs): def crash_process(): logging.error("Main method with an error. Exiting process.") os.exit(1) - return -def init_agent(configuration, context) -> Agent: +async def init_agent(configuration, context) -> Agent: full_class_name = configuration["className"] class_name = full_class_name.split(".")[-1] module_name = full_class_name[: -len(class_name) - 1] module = importlib.import_module(module_name) agent = getattr(module, class_name)() context_impl = DefaultAgentContext(configuration, context) - call_method_if_exists(agent, "init", configuration, context_impl) + await acall_method_if_exists(agent, "init", configuration, context_impl) return agent @@ -388,36 +357,35 @@ def get_persistent_state_directory(self) -> Optional[str]: class AgentServer(object): - def __init__(self, target: str, config: str, context: str): - self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10) + def __init__(self, target: str): self.target = target - self.grpc_server = grpc.server(self.thread_pool) + self.grpc_server = grpc.aio.server() self.port = self.grpc_server.add_insecure_port(target) + self.agent = None + async def init(self, config, context): configuration = json.loads(config) logging.debug("Configuration: " + json.dumps(configuration)) environment = configuration.get("environment", []) logging.debug("Environment: " + json.dumps(environment)) - for env in environment: key = env["key"] value = env["value"] logging.debug(f"Setting environment variable {key}={value}") os.environ[key] = value + self.agent = await init_agent(configuration, json.loads(context)) - self.agent = init_agent(configuration, json.loads(context)) - - def start(self): - call_method_if_exists(self.agent, "start") + async def start(self): + await acall_method_if_exists(self.agent, "start") call_method_new_thread_if_exists(self.agent, "main", crash_process) agent_pb2_grpc.add_AgentServiceServicer_to_server( AgentService(self.agent), self.grpc_server ) - self.grpc_server.start() + + await self.grpc_server.start() logging.info("GRPC Server started, listening on " + self.target) - def stop(self): - self.grpc_server.stop(None) - call_method_if_exists(self.agent, "close") - self.thread_pool.shutdown(wait=True) + async def stop(self): + await self.grpc_server.stop(None) + await acall_method_if_exists(self.agent, "close") diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/server_and_stub.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/server_and_stub.py index 77e360780..ea294cbbd 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/server_and_stub.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/server_and_stub.py @@ -29,18 +29,17 @@ def __init__(self, class_name, agent_config={}, context={}): self.config["className"] = class_name self.context = context self.server: Optional[AgentServer] = None - self.channel: Optional[grpc.Channel] = None + self.channel: Optional[grpc.aio.Channel] = None self.stub: Optional[AgentServiceStub] = None - def __enter__(self): - self.server = AgentServer( - "[::]:0", json.dumps(self.config), json.dumps(self.context) - ) - self.server.start() - self.channel = grpc.insecure_channel("localhost:%d" % self.server.port) + async def __aenter__(self): + self.server = AgentServer("[::]:0") + await self.server.init(json.dumps(self.config), json.dumps(self.context)) + await self.server.start() + self.channel = grpc.aio.insecure_channel("localhost:%d" % self.server.port) self.stub = AgentServiceStub(channel=self.channel) return self - def __exit__(self, *args): - self.channel.close() - self.server.stop() + async def __aexit__(self, *args): + await self.channel.close() + await self.server.stop() diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py index c76c26818..638f05dcc 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py @@ -61,8 +61,8 @@ pytest.param("double_value", "double_value", 42.0, 43.0, 44.0), ], ) -def test_process(input_type, output_type, value, key, header): - with ServerAndStub( +async def test_process(input_type, output_type, value, key, header): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.MyProcessor" ) as server_and_stub: record = GrpcRecord( @@ -79,8 +79,8 @@ def test_process(input_type, output_type, value, key, header): timestamp=43, ) response: ProcessorResponse - for response in server_and_stub.stub.process( - iter([ProcessorRequest(records=[record])]) + async for response in server_and_stub.stub.process( + [ProcessorRequest(records=[record])] ): assert len(response.results) == 1 assert response.results[0].record_id == record.record_id @@ -96,8 +96,8 @@ def test_process(input_type, output_type, value, key, header): assert result.timestamp == record.timestamp -def test_avro(): - with ServerAndStub( +async def test_avro(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.MyProcessor" ) as server_and_stub: requests = [] @@ -131,7 +131,9 @@ def test_avro(): fp.close() responses: list[ProcessorResponse] - responses = list(server_and_stub.stub.process(iter(requests))) + responses = [ + response async for response in server_and_stub.stub.process(requests) + ] response_schema = responses[0] assert len(response_schema.results) == 0 assert response_schema.HasField("schema") @@ -152,13 +154,13 @@ def test_avro(): fp.close() -def test_empty_record(): - with ServerAndStub( +async def test_empty_record(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.MyProcessor" ) as server_and_stub: response: ProcessorResponse - for response in server_and_stub.stub.process( - iter([ProcessorRequest(records=[GrpcRecord()])]) + async for response in server_and_stub.stub.process( + [ProcessorRequest(records=[GrpcRecord()])] ): assert len(response.results) == 1 assert response.results[0].record_id == 0 @@ -172,31 +174,25 @@ def test_empty_record(): assert result.HasField("timestamp") is False -def test_failing_record(): - with ServerAndStub( +async def test_failing_record(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.MyFailingProcessor" ) as server_and_stub: - for response in server_and_stub.stub.process( - iter([ProcessorRequest(records=[GrpcRecord()])]) + async for response in server_and_stub.stub.process( + [ProcessorRequest(records=[GrpcRecord()])] ): assert len(response.results) == 1 assert response.results[0].HasField("error") is True assert response.results[0].error == "failure" -def test_future_record(): - with ServerAndStub( +async def test_future_record(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.MyFutureProcessor" ) as server_and_stub: response: ProcessorResponse - for response in server_and_stub.stub.process( - iter( - [ - ProcessorRequest( - records=[GrpcRecord(value=Value(string_value="test"))] - ) - ] - ) + async for response in server_and_stub.stub.process( + [ProcessorRequest(records=[GrpcRecord(value=Value(string_value="test"))])] ): assert len(response.results) == 1 assert response.results[0].HasField("error") is False @@ -204,35 +200,35 @@ def test_future_record(): assert response.results[0].records[0].value.string_value == "test" -def test_info(): - with ServerAndStub( +async def test_info(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.MyProcessor" ) as server_and_stub: - info: InfoResponse = server_and_stub.stub.agent_info(empty_pb2.Empty()) + info: InfoResponse = await server_and_stub.stub.agent_info(empty_pb2.Empty()) assert info.json_info == '{"test-info-key": "test-info-value"}' -def test_init_one_parameter(): - with ServerAndStub( +async def test_init_one_parameter(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.ProcessorInitOneParameter", {"my-param": "my-value"}, ) as server_and_stub: - for response in server_and_stub.stub.process( - iter([ProcessorRequest(records=[GrpcRecord()])]) + async for response in server_and_stub.stub.process( + [ProcessorRequest(records=[GrpcRecord()])] ): assert len(response.results) == 1 result = response.results[0].records[0] assert result.value.string_value == "my-value" -def test_processor_use_context(): - with ServerAndStub( +async def test_processor_use_context(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_processor.ProcessorUseContext", {"my-param": "my-value"}, {"persistentStateDirectory": "/tmp/processor"}, ) as server_and_stub: - for response in server_and_stub.stub.process( - iter([ProcessorRequest(records=[GrpcRecord()])]) + async for response in server_and_stub.stub.process( + [ProcessorRequest(records=[GrpcRecord()])] ): assert len(response.results) == 1 result = response.results[0].records[0] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py index 9956d8b3f..0d882a2e4 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py @@ -30,12 +30,12 @@ from langstream_grpc.tests.server_and_stub import ServerAndStub -def test_write(): - with ServerAndStub( +async def test_write(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_sink.MySink" ) as server_and_stub: - def requests(): + async def requests(): schema = { "type": "record", "name": "Test", @@ -60,7 +60,10 @@ def requests(): fp.close() responses: list[SinkResponse] - responses = list(server_and_stub.stub.write(iter(requests()))) + responses = [ + response async for response in server_and_stub.stub.write(requests()) + ] + assert len(responses) == 1 assert responses[0].record_id == 43 assert len(server_and_stub.server.agent.written_records) == 1 @@ -70,47 +73,45 @@ def requests(): ) -def test_write_error(): - with ServerAndStub( +async def test_write_error(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_sink.MyErrorSink" ) as server_and_stub: responses: list[SinkResponse] - responses = list( - server_and_stub.stub.write( - iter( - [ - SinkRequest( - record=GrpcRecord( - value=Value(string_value="test"), - ) + responses = [ + response + async for response in server_and_stub.stub.write( + [ + SinkRequest( + record=GrpcRecord( + value=Value(string_value="test"), ) - ] - ) + ) + ] ) - ) + ] assert len(responses) == 1 assert responses[0].error == "test-error" -def test_write_future(): - with ServerAndStub( +async def test_write_future(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_sink.MyFutureSink" ) as server_and_stub: responses: list[SinkResponse] - responses = list( - server_and_stub.stub.write( - iter( - [ - SinkRequest( - record=GrpcRecord( - record_id=42, - value=Value(string_value="test"), - ) + responses = [ + response + async for response in server_and_stub.stub.write( + [ + SinkRequest( + record=GrpcRecord( + record_id=42, + value=Value(string_value="test"), ) - ] - ) + ) + ] ) - ) + ] assert len(responses) == 1 assert responses[0].record_id == 42 assert len(server_and_stub.server.agent.written_records) == 1 diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py index 64e2d39ec..3f019ea31 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py @@ -14,9 +14,8 @@ # limitations under the License. # +import asyncio import json -import queue -import time from io import BytesIO from typing import List @@ -35,24 +34,24 @@ @pytest.fixture -def server_and_stub(): - with ServerAndStub( +async def server_and_stub(): + async with ServerAndStub( "langstream_grpc.tests.test_grpc_source.MySource" ) as server_and_stub: yield server_and_stub -def test_read(server_and_stub): +async def test_read(server_and_stub): stop = False - def requests(): + async def requests(): while not stop: - time.sleep(0.1) - yield from () + await asyncio.sleep(0.1) + yield responses: list[SourceResponse] = [] i = 0 - for response in server_and_stub.stub.read(iter(requests())): + async for response in server_and_stub.stub.read(requests()): responses.append(response) i += 1 stop = i == 4 @@ -92,24 +91,21 @@ def requests(): assert record.value.long_value == 43 -def test_commit(server_and_stub): - to_commit = queue.Queue() +async def test_commit(server_and_stub): + to_commit = asyncio.Queue() - def send_commit(): + async def send_commit(): committed = 0 while committed < 3: - try: - commit_id = to_commit.get(True, 1) - yield SourceRequest(committed_records=[commit_id]) - committed += 1 - except queue.Empty: - pass + commit_id = await to_commit.get() + yield SourceRequest(committed_records=[commit_id]) + committed += 1 with pytest.raises(grpc.RpcError): response: SourceResponse - for response in server_and_stub.stub.read(iter(send_commit())): + async for response in server_and_stub.stub.read(send_commit()): for record in response.records: - to_commit.put(record.record_id) + await to_commit.put(record.record_id) sent = server_and_stub.server.agent.sent committed = server_and_stub.server.agent.committed @@ -118,24 +114,21 @@ def send_commit(): assert committed[1].value() == sent[1]["value"] -def test_permanent_failure(server_and_stub): - to_fail = queue.Queue() +async def test_permanent_failure(server_and_stub): + to_fail = asyncio.Queue() - def send_failure(): - try: - record_id = to_fail.get(True) - yield SourceRequest( - permanent_failure=PermanentFailure( - record_id=record_id, error_message="failure" - ) + async def send_failure(): + record_id = await to_fail.get() + yield SourceRequest( + permanent_failure=PermanentFailure( + record_id=record_id, error_message="failure" ) - except queue.Empty: - pass + ) response: SourceResponse - for response in server_and_stub.stub.read(iter(send_failure())): + async for response in server_and_stub.stub.read(send_failure()): for record in response.records: - to_fail.put(record.record_id) + await to_fail.put(record.record_id) failures = server_and_stub.server.agent.failures assert len(failures) == 1 diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/pyproject.toml b/langstream-runtime/langstream-runtime-impl/src/main/python/pyproject.toml index 45815d901..089786a9f 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/pyproject.toml +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/pyproject.toml @@ -65,3 +65,6 @@ exclude = [ "proto", ] +[tool.pytest.ini_options] +minversion = "6.0" +asyncio_mode = "auto" diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/tox.ini b/langstream-runtime/langstream-runtime-impl/src/main/python/tox.ini index 1939db686..91c3155b3 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/tox.ini +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/tox.ini @@ -9,6 +9,7 @@ description = run unit tests deps = pipenv pytest + pytest-asyncio commands = pipenv sync pytest {posargs:langstream_grpc}