Add gRPC calls in dataset generator.
Tested on local machines. An assertion failure on the server side causes the vLLM worker to crash. This may be caused by sending an empty prompt to the server, as described in vllm-project/vllm#7632 and vllm-project/vllm#7746. Needs further inspection later.
1 parent 3043967, commit cced6e7. Showing 10 changed files with 620 additions and 15 deletions.
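If the crash really is triggered by empty prompts reaching the engine, as the linked vLLM issues suggest, one low-cost mitigation would be to validate requests at the gRPC boundary. Below is a minimal sketch of such a guard inside the servicer; the INVALID_ARGUMENT rejection is an illustrative assumption, not part of this commit.

# Hypothetical guard (not in this commit): reject empty prompts before
# they ever reach the vLLM engine.
async def processChatReq(self, request, context):
    if not request.prompt.strip():
        # grpc.aio contexts expose abort() as a coroutine.
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                            "prompt must be non-empty")
    ...  # proceed with engine.generate() as in the servicer below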
@@ -0,0 +1 @@
grpc/__pycache__
@@ -0,0 +1,66 @@
import asyncio
import logging
import os

import grpc
import chat_pb2
import chat_pb2_grpc

import vllm
from vllm.engine.arg_utils import AsyncEngineArgs


MODEL_PATH = '~/personal/projects/vllm_inference/model_data/opt-1.3b/'


class LlmEngine(chat_pb2_grpc.LlmEngineServicer):
    def __init__(self, *args, **kwargs):
        model_dir = os.path.expanduser(MODEL_PATH)
        self.engine = vllm.AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model=model_dir,
                enforce_eager=True,
                trust_remote_code=True,
                max_model_len=2048,
            )
        )

    async def processChatReq(self, request: chat_pb2.ChatReq, context: grpc.aio.ServicerContext):
        # NOTE: session_id is not a kwarg of stock AsyncLLMEngine.generate();
        # this relies on a locally modified vLLM build.
        results_generator = self.engine.generate(
            request.prompt,
            vllm.SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2048, min_tokens=20),
            request_id=request.request_id,
            session_id=request.session_id,
            # refill_requests=refill_requests
        )
        # Drain the stream and keep only the final output (non-streaming reply).
        final_output = None
        async for request_output in results_generator:
            final_output = request_output

        text_output = [output.text for output in final_output.outputs]
        return chat_pb2.ChatResp(answer=text_output[0])

    async def processInfoReq(self, request: chat_pb2.InfoReq, context: grpc.aio.ServicerContext):
        # Also non-standard: reaches into the scheduler of the modified vLLM.
        self.engine.engine.scheduler[0].refill_requests.append(request.session_id)
        return chat_pb2.InfoResp(success=True)


async def serve() -> None:
    server = grpc.aio.server()
    chat_pb2_grpc.add_LlmEngineServicer_to_server(LlmEngine(), server)
    listen_addr = "[::]:50051"
    server.add_insecure_port(listen_addr)
    logging.info("Starting server on %s", listen_addr)
    await server.start()
    await server.wait_for_termination()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(serve())
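For a quick manual check of the server above, a plain synchronous client is enough. This is a sketch, assuming the generated chat_pb2 / chat_pb2_grpc modules are importable and the server is listening on localhost:50051:

# Minimal synchronous smoke test (assumes the server above is running).
import grpc
import chat_pb2
import chat_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = chat_pb2_grpc.LlmEngineStub(channel)
    resp = stub.processChatReq(
        chat_pb2.ChatReq(prompt="Hello!", request_id="1", session_id=1))
    print(resp.answer)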
@@ -0,0 +1,24 @@
syntax = "proto3";
package ChatVllm;

service LlmEngine {
    rpc processChatReq(ChatReq) returns (ChatResp) {}
    rpc processInfoReq(InfoReq) returns (InfoResp) {}
}

message ChatReq {
    string prompt = 1;
    int32 session_id = 2;
    string request_id = 3;
}

message InfoReq {
    int32 session_id = 1;
}

message InfoResp {
    bool success = 1;
}

message ChatResp {
    string answer = 1;
}
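The chat_pb2 modules below are generated from this proto definition. They can be regenerated with grpcio-tools; the exact invocation is an assumption, since the commit includes no build script, and the source filename chat.proto is inferred from the generated module names:

# Regenerate chat_pb2.py, chat_pb2.pyi and chat_pb2_grpc.py from the proto.
python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. --grpc_python_out=. chat.proto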
@@ -0,0 +1,33 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class ChatReq(_message.Message):
    __slots__ = ("prompt", "session_id", "request_id")
    PROMPT_FIELD_NUMBER: _ClassVar[int]
    SESSION_ID_FIELD_NUMBER: _ClassVar[int]
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    prompt: str
    session_id: int
    request_id: str
    def __init__(self, prompt: _Optional[str] = ..., session_id: _Optional[int] = ..., request_id: _Optional[str] = ...) -> None: ...

class InfoReq(_message.Message):
    __slots__ = ("session_id",)
    SESSION_ID_FIELD_NUMBER: _ClassVar[int]
    session_id: int
    def __init__(self, session_id: _Optional[int] = ...) -> None: ...

class InfoResp(_message.Message):
    __slots__ = ("success",)
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    success: bool
    def __init__(self, success: bool = ...) -> None: ...

class ChatResp(_message.Message):
    __slots__ = ("answer",)
    ANSWER_FIELD_NUMBER: _ClassVar[int]
    answer: str
    def __init__(self, answer: _Optional[str] = ...) -> None: ...
@@ -0,0 +1,140 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

import chat_pb2 as chat__pb2

GRPC_GENERATED_VERSION = '1.66.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in chat_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class LlmEngineStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.processChatReq = channel.unary_unary(
                '/ChatVllm.LlmEngine/processChatReq',
                request_serializer=chat__pb2.ChatReq.SerializeToString,
                response_deserializer=chat__pb2.ChatResp.FromString,
                _registered_method=True)
        self.processInfoReq = channel.unary_unary(
                '/ChatVllm.LlmEngine/processInfoReq',
                request_serializer=chat__pb2.InfoReq.SerializeToString,
                response_deserializer=chat__pb2.InfoResp.FromString,
                _registered_method=True)


class LlmEngineServicer(object):
    """Missing associated documentation comment in .proto file."""

    def processChatReq(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def processInfoReq(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_LlmEngineServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'processChatReq': grpc.unary_unary_rpc_method_handler(
                    servicer.processChatReq,
                    request_deserializer=chat__pb2.ChatReq.FromString,
                    response_serializer=chat__pb2.ChatResp.SerializeToString,
            ),
            'processInfoReq': grpc.unary_unary_rpc_method_handler(
                    servicer.processInfoReq,
                    request_deserializer=chat__pb2.InfoReq.FromString,
                    response_serializer=chat__pb2.InfoResp.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'ChatVllm.LlmEngine', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('ChatVllm.LlmEngine', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class LlmEngine(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def processChatReq(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/ChatVllm.LlmEngine/processChatReq',
            chat__pb2.ChatReq.SerializeToString,
            chat__pb2.ChatResp.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def processInfoReq(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/ChatVllm.LlmEngine/processInfoReq',
            chat__pb2.InfoReq.SerializeToString,
            chat__pb2.InfoResp.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
@@ -0,0 +1,60 @@
import asyncio
import time
import logging

import grpc
import chat_pb2
import chat_pb2_grpc


task_list = []


def blocking_task():
    # Simulates CPU-bound/blocking work running alongside the in-flight RPC.
    print("doing other work")
    time.sleep(3)
    print("sleeping call ----- after")
    current_time = time.strftime("%H:%M:%S", time.localtime())
    print(current_time)


async def run() -> None:
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = chat_pb2_grpc.LlmEngineStub(channel)
        request = chat_pb2.ChatReq(prompt="Who is the most powerful person in the world?", request_id="4", session_id=4)
        task = asyncio.create_task(make_grpc_request(stub, request))
        print("sleeping call ----- before")
        current_time = time.strftime("%H:%M:%S", time.localtime())
        print(current_time)
        # Run the blocking work in a thread so the RPC can proceed concurrently.
        await asyncio.to_thread(blocking_task)
        task_list.append(task)
        current_time = time.strftime("%H:%M:%S", time.localtime())
        print("run function time: ", current_time)
        # NOTE: the RPC task is never awaited here, so it may be cancelled when
        # the channel closes. Earlier experiments, kept for reference:
        # await task_list[0]
        # response = stub.processChatReq(chat_pb2.ChatReq(prompt="Who is the most powerful person in the world?", request_id="1", session_id=1))
        # stub2 = chat_pb2_grpc.LlmEngineStub(channel)
        # resp2 = await stub.processInfoReq(chat_pb2.InfoReq(session_id=0))
        # task_list.append(response)
        # real_resp = await response
        # print("Greeter client received: " + real_resp.answer)
        # with grpc.insecure_channel("localhost:50051") as channel:
        #     stub = chat_pb2_grpc.LlmEngineStub(channel)
        #     stub.processInfoReq(chat_pb2.InfoReq(session_id=0))
        #     print("Greeter client received: " + response.answer)
        # print(resp2)


async def make_grpc_request(stub, request):
    current_time = time.strftime("%H:%M:%S", time.localtime())
    print("time making request: ", current_time)
    response = await stub.processChatReq(request)
    print("Response received:", response.answer)
    current_time = time.strftime("%H:%M:%S", time.localtime())
    print("time finished request: ", current_time)


if __name__ == "__main__":
    logging.basicConfig()
    asyncio.run(run())
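Because run() exits without awaiting the RPC task, the response can be lost when the channel closes. A variant that awaits both the blocking work and the RPC, sketched against the same stubs and helpers defined above (the request values are illustrative):

# Sketch: await both the RPC and the blocking work, so neither is dropped.
async def run_gathered() -> None:
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = chat_pb2_grpc.LlmEngineStub(channel)
        request = chat_pb2.ChatReq(prompt="Hello?", request_id="5", session_id=5)
        await asyncio.gather(
            make_grpc_request(stub, request),   # the RPC helper defined above
            asyncio.to_thread(blocking_task),   # the simulated blocking work
        )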