Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Python asyncio gRPC #722

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ tox = "*"
black = "*"
ruff = "*"
pytest = "*"
pytest-asyncio = "*"
grpcio-tools = "*"

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@
# limitations under the License.
#

import asyncio
import logging
import sys

from langstream_grpc.grpc_service import AgentServer


async def main(target, config, context):
server = AgentServer(target)
await server.init(config, context)
await server.start()
await server.grpc_server.wait_for_termination()
await server.stop()


if __name__ == "__main__":
logging.addLevelName(logging.WARNING, "WARN")
logging.basicConfig(
Expand All @@ -34,7 +44,4 @@
)
sys.exit(1)

server = AgentServer(sys.argv[1], sys.argv[2], sys.argv[3])
server.start()
server.grpc_server.wait_for_termination()
server.stop()
asyncio.run(main(*sys.argv[1:]))
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#

import concurrent
import asyncio
import importlib
import json
import os
Expand All @@ -23,7 +23,7 @@
import threading
from concurrent.futures import Future
from io import BytesIO
from typing import Iterable, Union, List, Tuple, Any, Optional, Dict
from typing import Union, List, Tuple, Any, Optional, Dict, AsyncIterable

import fastavro
import grpc
Expand Down Expand Up @@ -78,30 +78,28 @@ def __init__(self, agent: Union[Agent, Source, Sink, Processor]):
self.schemas = {}
self.client_schemas = {}

def agent_info(self, _, __):
info = call_method_if_exists(self.agent, "agent_info") or {}
async def agent_info(self, _, __):
info = await acall_method_if_exists(self.agent, "agent_info") or {}
return InfoResponse(json_info=json.dumps(info))

def get_topic_producer_records(self, request_iterator, context):
# TODO: to be implementedbla
for _ in request_iterator:
yield None
async def get_topic_producer_records(self, request_iterator, context):
# TODO: to be implemented
async for _ in request_iterator:
yield

def read(self, requests: Iterable[SourceRequest], _):
async def read(self, requests: AsyncIterable[SourceRequest], _):
read_records = {}
op_result = []
read_thread = threading.Thread(
target=self.handle_read_requests,
args=(requests, read_records, op_result),
)
last_record_id = 0
read_thread.start()
read_requests_task = asyncio.create_task(
self.handle_read_requests(requests, read_records, op_result)
)
while True:
if len(op_result) > 0:
if op_result[0] is True:
break
raise op_result[0]
records = self.agent.read()
records = await asyncio.to_thread(self.agent.read)
if len(records) > 0:
records = [wrap_in_record(record) for record in records]
grpc_records = []
Expand All @@ -115,25 +113,25 @@ def read(self, requests: Iterable[SourceRequest], _):
grpc_records[i].record_id = last_record_id
read_records[last_record_id] = record
yield SourceResponse(records=grpc_records)
read_thread.join()
read_requests_task.cancel()

def handle_read_requests(
async def handle_read_requests(
self,
requests: Iterable[SourceRequest],
requests: AsyncIterable[SourceRequest],
read_records: Dict[int, Record],
read_result,
):
try:
for request in requests:
async for request in requests:
if len(request.committed_records) > 0:
for record_id in request.committed_records:
record = read_records.pop(record_id, None)
if record is not None:
call_method_if_exists(self.agent, "commit", record)
await acall_method_if_exists(self.agent, "commit", record)
if request.HasField("permanent_failure"):
failure = request.permanent_failure
record = read_records.pop(failure.record_id, None)
call_method_if_exists(
await acall_method_if_exists(
self.agent,
"permanent_failure",
record,
Expand All @@ -159,83 +157,49 @@ def handle_requests(handler, requests):
pass
thread.join()

def process(self, requests: Iterable[ProcessorRequest], _):
return self.handle_requests(self.handle_process_requests, requests)

def process_record(
self, source_record, get_processed_fn, get_processed_args, process_results
):
grpc_result = ProcessorResult(record_id=source_record.record_id)
try:
processed_records = get_processed_fn(*get_processed_args)
if isinstance(processed_records, Future):
processed_records.add_done_callback(
lambda f: self.process_record(
source_record, f.result, (), process_results
)
)
else:
for record in processed_records:
schemas, grpc_record = self.to_grpc_record(wrap_in_record(record))
for schema in schemas:
process_results.put(ProcessorResponse(schema=schema))
grpc_result.records.append(grpc_record)
process_results.put(ProcessorResponse(results=[grpc_result]))
except Exception as e:
grpc_result.error = str(e)
process_results.put(ProcessorResponse(results=[grpc_result]))

def handle_process_requests(
self, requests: Iterable[ProcessorRequest], process_results
):
for request in requests:
async def process(self, requests: AsyncIterable[ProcessorRequest], _):
async for request in requests:
if request.HasField("schema"):
schema = fastavro.parse_schema(json.loads(request.schema.value))
self.client_schemas[request.schema.schema_id] = schema
if len(request.records) > 0:
for source_record in request.records:
self.process_record(
source_record,
lambda r: self.agent.process(self.from_grpc_record(r)),
(source_record,),
process_results,
)
process_results.put(True)

def write(self, requests: Iterable[SinkRequest], _):
return self.handle_requests(self.handle_write_requests, requests)

def write_record(
self, source_record, get_written_fn, get_written_args, write_results
):
try:
result = get_written_fn(*get_written_args)
if isinstance(result, Future):
result.add_done_callback(
lambda f: self.write_record(
source_record, f.result, (), write_results
)
)
else:
write_results.put(SinkResponse(record_id=source_record.record_id))
except Exception as e:
write_results.put(
SinkResponse(record_id=source_record.record_id, error=str(e))
)

def handle_write_requests(self, requests: Iterable[SinkRequest], write_results):
for request in requests:
grpc_result = ProcessorResult(record_id=source_record.record_id)
try:
processed_records = await asyncio.to_thread(
self.agent.process, self.from_grpc_record(source_record)
)
if isinstance(processed_records, Future):
processed_records = await asyncio.wrap_future(
processed_records
)
for record in processed_records:
schemas, grpc_record = self.to_grpc_record(
wrap_in_record(record)
)
for schema in schemas:
yield ProcessorResponse(schema=schema)
grpc_result.records.append(grpc_record)
yield ProcessorResponse(results=[grpc_result])
except Exception as e:
grpc_result.error = str(e)
yield ProcessorResponse(results=[grpc_result])

async def write(self, requests: AsyncIterable[SinkRequest], context):
async for request in requests:
if request.HasField("schema"):
schema = fastavro.parse_schema(json.loads(request.schema.value))
self.client_schemas[request.schema.schema_id] = schema
if request.HasField("record"):
self.write_record(
request.record,
lambda r: self.agent.write(self.from_grpc_record(r)),
(request.record,),
write_results,
)
write_results.put(True)
try:
result = await asyncio.to_thread(
self.agent.write, self.from_grpc_record(request.record)
)
if isinstance(result, Future):
await asyncio.wrap_future(result)
yield SinkResponse(record_id=request.record.record_id)
except Exception as e:
yield SinkResponse(record_id=request.record.record_id, error=str(e))

def from_grpc_record(self, record: GrpcRecord) -> SimpleRecord:
return RecordWithId(
Expand Down Expand Up @@ -333,6 +297,12 @@ def call_method_if_exists(klass, method, *args, **kwargs):
return None


async def acall_method_if_exists(klass, method, *args, **kwargs):
return await asyncio.to_thread(
call_method_if_exists, klass, method, *args, **kwargs
)


class MainExecutor(threading.Thread):
def __init__(self, onError, klass, method, *args, **kwargs):
threading.Thread.__init__(self)
Expand Down Expand Up @@ -364,17 +334,16 @@ def call_method_new_thread_if_exists(klass, methodName, *args, **kwargs):
def crash_process():
logging.error("Main method with an error. Exiting process.")
os.exit(1)
return


def init_agent(configuration, context) -> Agent:
async def init_agent(configuration, context) -> Agent:
full_class_name = configuration["className"]
class_name = full_class_name.split(".")[-1]
module_name = full_class_name[: -len(class_name) - 1]
module = importlib.import_module(module_name)
agent = getattr(module, class_name)()
context_impl = DefaultAgentContext(configuration, context)
call_method_if_exists(agent, "init", configuration, context_impl)
await acall_method_if_exists(agent, "init", configuration, context_impl)
return agent


Expand All @@ -388,36 +357,35 @@ def get_persistent_state_directory(self) -> Optional[str]:


class AgentServer(object):
def __init__(self, target: str, config: str, context: str):
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
def __init__(self, target: str):
self.target = target
self.grpc_server = grpc.server(self.thread_pool)
self.grpc_server = grpc.aio.server()
self.port = self.grpc_server.add_insecure_port(target)
self.agent = None

async def init(self, config, context):
configuration = json.loads(config)
logging.debug("Configuration: " + json.dumps(configuration))
environment = configuration.get("environment", [])
logging.debug("Environment: " + json.dumps(environment))

for env in environment:
key = env["key"]
value = env["value"]
logging.debug(f"Setting environment variable {key}={value}")
os.environ[key] = value
self.agent = await init_agent(configuration, json.loads(context))

self.agent = init_agent(configuration, json.loads(context))

def start(self):
call_method_if_exists(self.agent, "start")
async def start(self):
await acall_method_if_exists(self.agent, "start")
call_method_new_thread_if_exists(self.agent, "main", crash_process)

agent_pb2_grpc.add_AgentServiceServicer_to_server(
AgentService(self.agent), self.grpc_server
)
self.grpc_server.start()

await self.grpc_server.start()
logging.info("GRPC Server started, listening on " + self.target)

def stop(self):
self.grpc_server.stop(None)
call_method_if_exists(self.agent, "close")
self.thread_pool.shutdown(wait=True)
async def stop(self):
await self.grpc_server.stop(None)
await acall_method_if_exists(self.agent, "close")
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,17 @@ def __init__(self, class_name, agent_config={}, context={}):
self.config["className"] = class_name
self.context = context
self.server: Optional[AgentServer] = None
self.channel: Optional[grpc.Channel] = None
self.channel: Optional[grpc.aio.Channel] = None
self.stub: Optional[AgentServiceStub] = None

def __enter__(self):
self.server = AgentServer(
"[::]:0", json.dumps(self.config), json.dumps(self.context)
)
self.server.start()
self.channel = grpc.insecure_channel("localhost:%d" % self.server.port)
async def __aenter__(self):
self.server = AgentServer("[::]:0")
await self.server.init(json.dumps(self.config), json.dumps(self.context))
await self.server.start()
self.channel = grpc.aio.insecure_channel("localhost:%d" % self.server.port)
self.stub = AgentServiceStub(channel=self.channel)
return self

def __exit__(self, *args):
self.channel.close()
self.server.stop()
async def __aexit__(self, *args):
await self.channel.close()
await self.server.stop()
Loading
Loading