Add grpc calls in dataset generator.
Tested on local machines. An assertion failure occurs on the server side, which causes the vLLM worker to crash. This may be caused by sending an empty prompt to the server, as described in
vllm-project/vllm#7632 and
vllm-project/vllm#7746. This needs further inspection later.
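A possible mitigation, not part of this commit, is to reject empty prompts at the RPC boundary before they reach the engine. A minimal sketch against the LlmEngine servicer added in grpc/async_server.py below:

async def processChatReq(self, request, context):
    # Hypothetical guard: empty prompts appear to trigger the server-side
    # assertion failure tracked in vllm-project/vllm#7632 and #7746.
    if not request.prompt.strip():
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                            "prompt must not be empty")
    # ... continue with self.engine.generate(...) as in async_server.py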
Anyonering committed Oct 2, 2024
1 parent 3043967 commit cced6e7
Showing 10 changed files with 620 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
grpc/__pycache__
66 changes: 66 additions & 0 deletions grpc/async_server.py
@@ -0,0 +1,66 @@
import asyncio
import logging
import os

import grpc
import chat_pb2
import chat_pb2_grpc

import vllm
from vllm.engine.arg_utils import AsyncEngineArgs

MODEL_PATH = '~/personal/projects/vllm_inference/model_data/opt-1.3b/'

class LlmEngine(chat_pb2_grpc.LlmEngineServicer):
    def __init__(self, *args, **kwargs):
        # Build the async engine from the local model checkpoint.
        model_dir = os.path.expanduser(MODEL_PATH)
        self.engine = vllm.AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model=model_dir,
                enforce_eager=True,
                trust_remote_code=True,
                max_model_len=2048,
            )
        )

    async def processChatReq(self, request: chat_pb2.ChatReq, context: grpc.aio.ServicerContext):
        # Submit the prompt to the async engine; session_id is a custom argument
        # consumed by the locally modified scheduler (see processInfoReq below).
        results_generator = self.engine.generate(
            request.prompt,
            vllm.SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2048, min_tokens=20),
            request_id=request.request_id,
            session_id=request.session_id,
            # refill_requests=refill_requests
        )
        # Drain the stream and keep only the final output.
        final_output = None
        async for request_output in results_generator:
            final_output = request_output

        # prompt = final_output.prompt
        text_output = [output.text for output in final_output.outputs]
        return chat_pb2.ChatResp(answer=text_output[0])

    async def processInfoReq(self, request: chat_pb2.InfoReq, context: grpc.aio.ServicerContext):
        # Flag the session for a refill in the (locally patched) scheduler.
        self.engine.engine.scheduler[0].refill_requests.append(request.session_id)
        return chat_pb2.InfoResp(success=True)


async def serve() -> None:
server = grpc.aio.server()
chat_pb2_grpc.add_LlmEngineServicer_to_server(LlmEngine(), server)
listen_addr = "[::]:50051"
server.add_insecure_port(listen_addr)
logging.info("Starting server on %s", listen_addr)
await server.start()
await server.wait_for_termination()


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
asyncio.run(serve())
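A refinement worth considering, sketched here under the assumption that graceful shutdown is wanted and not part of this commit: drain in-flight generations before the process exits, which grpc.aio supports via server.stop() with a grace period.

async def serve() -> None:
    server = grpc.aio.server()
    chat_pb2_grpc.add_LlmEngineServicer_to_server(LlmEngine(), server)
    server.add_insecure_port("[::]:50051")
    await server.start()
    try:
        await server.wait_for_termination()
    except asyncio.CancelledError:
        # asyncio.run() cancels this task on Ctrl+C; give in-flight RPCs
        # up to 5 seconds to finish before shutting down.
        await server.stop(5)
        raise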
24 changes: 24 additions & 0 deletions grpc/chat.proto
@@ -0,0 +1,24 @@
syntax = "proto3";
package ChatVllm;

service LlmEngine {
rpc processChatReq(ChatReq) returns (ChatResp){}
rpc processInfoReq(InfoReq) returns (InfoResp){}
}

message ChatReq {
string prompt = 1;
int32 session_id = 2;
string request_id = 3;
}

message InfoReq {
int32 session_id = 1;
}

message InfoResp {
bool success = 1;
}

message ChatResp {
string answer = 1;
}
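The generated modules below (chat_pb2.py, chat_pb2.pyi, chat_pb2_grpc.py) can be reproduced from this definition with grpcio-tools; a sketch of the invocation, assuming it is run from the grpc/ directory with grpcio-tools installed:

# Regenerate the Python stubs from chat.proto.
from grpc_tools import protoc

protoc.main([
    "protoc",
    "-I.",
    "--python_out=.",       # chat_pb2.py
    "--pyi_out=.",          # chat_pb2.pyi
    "--grpc_python_out=.",  # chat_pb2_grpc.py
    "chat.proto",
])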
44 changes: 44 additions & 0 deletions grpc/chat_pb2.py

(Generated protobuf module; contents not rendered in this diff.)

33 changes: 33 additions & 0 deletions grpc/chat_pb2.pyi
@@ -0,0 +1,33 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class ChatReq(_message.Message):
__slots__ = ("prompt", "session_id", "request_id")
PROMPT_FIELD_NUMBER: _ClassVar[int]
SESSION_ID_FIELD_NUMBER: _ClassVar[int]
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
prompt: str
session_id: int
request_id: str
def __init__(self, prompt: _Optional[str] = ..., session_id: _Optional[int] = ..., request_id: _Optional[str] = ...) -> None: ...

class InfoReq(_message.Message):
__slots__ = ("session_id",)
SESSION_ID_FIELD_NUMBER: _ClassVar[int]
session_id: int
def __init__(self, session_id: _Optional[int] = ...) -> None: ...

class InfoResp(_message.Message):
__slots__ = ("success",)
SUCCESS_FIELD_NUMBER: _ClassVar[int]
success: bool
def __init__(self, success: bool = ...) -> None: ...

class ChatResp(_message.Message):
__slots__ = ("answer",)
ANSWER_FIELD_NUMBER: _ClassVar[int]
answer: str
def __init__(self, answer: _Optional[str] = ...) -> None: ...
140 changes: 140 additions & 0 deletions grpc/chat_pb2_grpc.py
@@ -0,0 +1,140 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

import chat_pb2 as chat__pb2

GRPC_GENERATED_VERSION = '1.66.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True

if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in chat_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)


class LlmEngineStub(object):
"""Missing associated documentation comment in .proto file."""

def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.processChatReq = channel.unary_unary(
'/ChatVllm.LlmEngine/processChatReq',
request_serializer=chat__pb2.ChatReq.SerializeToString,
response_deserializer=chat__pb2.ChatResp.FromString,
_registered_method=True)
self.processInfoReq = channel.unary_unary(
'/ChatVllm.LlmEngine/processInfoReq',
request_serializer=chat__pb2.InfoReq.SerializeToString,
response_deserializer=chat__pb2.InfoResp.FromString,
_registered_method=True)


class LlmEngineServicer(object):
"""Missing associated documentation comment in .proto file."""

def processChatReq(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def processInfoReq(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_LlmEngineServicer_to_server(servicer, server):
rpc_method_handlers = {
'processChatReq': grpc.unary_unary_rpc_method_handler(
servicer.processChatReq,
request_deserializer=chat__pb2.ChatReq.FromString,
response_serializer=chat__pb2.ChatResp.SerializeToString,
),
'processInfoReq': grpc.unary_unary_rpc_method_handler(
servicer.processInfoReq,
request_deserializer=chat__pb2.InfoReq.FromString,
response_serializer=chat__pb2.InfoResp.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'ChatVllm.LlmEngine', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('ChatVllm.LlmEngine', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class LlmEngine(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def processChatReq(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/ChatVllm.LlmEngine/processChatReq',
chat__pb2.ChatReq.SerializeToString,
chat__pb2.ChatResp.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def processInfoReq(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/ChatVllm.LlmEngine/processInfoReq',
chat__pb2.InfoReq.SerializeToString,
chat__pb2.InfoResp.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
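The experimental LlmEngine wrapper at the bottom of the generated module allows one-shot, synchronous calls without managing a channel explicitly; a hypothetical usage sketch, assuming the server from async_server.py is listening on localhost:50051:

# One-shot call through the experimental generated wrapper (hypothetical usage).
import chat_pb2
from chat_pb2_grpc import LlmEngine

resp = LlmEngine.processChatReq(
    chat_pb2.ChatReq(prompt="Hello there.", request_id="1", session_id=1),
    "localhost:50051",
    insecure=True,  # plaintext channel, matching the server's insecure port
)
print(resp.answer)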
60 changes: 60 additions & 0 deletions grpc/client.py
@@ -0,0 +1,60 @@

import asyncio
import time
import logging

import grpc
import chat_pb2
import chat_pb2_grpc


task_list = []


def blocking_task():
    # Simulate blocking work running alongside the in-flight gRPC request.
    print("doing other work")
    time.sleep(3)
    print("sleeping call ----- after")
    current_time = time.strftime("%H:%M:%S", time.localtime())
    print(current_time)

async def run() -> None:
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = chat_pb2_grpc.LlmEngineStub(channel)
        request = chat_pb2.ChatReq(prompt="Who is the most powerful person in the world?", request_id="4", session_id=4)
        # Fire the RPC as a background task so blocking work can proceed concurrently.
        task = asyncio.create_task(make_grpc_request(stub, request))
        print("sleeping call ----- before")
        current_time = time.strftime("%H:%M:%S", time.localtime())
        print(current_time)
        # Run the blocking work in a worker thread so the event loop stays responsive.
        await asyncio.to_thread(blocking_task)
        task_list.append(task)
        current_time = time.strftime("%H:%M:%S", time.localtime())
        print("run function time: ", current_time)
        # Note: the task is not awaited here, so the channel may close while the
        # RPC is still in flight.
        # await task_list[0]
# response = stub.processChatReq(chat_pb2.ChatReq(prompt="Who is the most powerful person in the world?",request_id="1",session_id=1))
# stub2 = chat_pb2_grpc.LlmEngineStub(channel)
# resp2 = await stub.processInfoReq(chat_pb2.InfoReq(session_id=0))
# task_list.append(response)
# real_resp = await response
# print("Greeter client received: " + real_resp.answer)
# with grpc.insecure_channel("localhost:50051") as channel:
# stub = chat_pb2_grpc.LlmEngineStub(channel)
# stub.processInfoReq(chat_pb2.InfoReq(session_id=0))
# print("Greeter client received: " + response.answer)
# print(resp2)

async def make_grpc_request(stub, request):
    # Issue the RPC and log wall-clock timestamps around it.
    current_time = time.strftime("%H:%M:%S", time.localtime())
    print("time making request: ", current_time)
    response = await stub.processChatReq(request)
    print("Response received:", response.answer)
    current_time = time.strftime("%H:%M:%S", time.localtime())
    print("time finished request: ", current_time)


if __name__ == "__main__":
    logging.basicConfig()
    asyncio.run(run())
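For a more compact concurrency check than the timing experiment above, several requests can be issued at once with asyncio.gather; a sketch (not part of this commit), reusing the same stub and server address:

# Hypothetical smoke test: fire several ChatReq RPCs concurrently.
import asyncio
import grpc
import chat_pb2
import chat_pb2_grpc

async def smoke_test(n: int = 4) -> None:
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = chat_pb2_grpc.LlmEngineStub(channel)
        requests = [
            chat_pb2.ChatReq(prompt=f"Test prompt {i}", request_id=str(i), session_id=i)
            for i in range(n)
        ]
        responses = await asyncio.gather(*(stub.processChatReq(r) for r in requests))
        for resp in responses:
            print(resp.answer)

Run it with asyncio.run(smoke_test()) once the server is up.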
