Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first proto example #92

Merged
merged 2 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/Protobuf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Create new Model


## Generate a stubbed out Python server

```shell
pip install grpcio-tools
```


To generate the Python files from the proto definition, run:
```shell
python -m grpc_tools.protoc --proto_path=proto generate.proto --python_out=llms/repeater --pyi_out=llms/repeater --grpc_python_out=llms/repeater
```

This creates the following files:
* generate_pb2_grpc.py
* generate_pb2.py
* generate_pb2.pyi

The generated modules are then used in `repeater.py` to implement the gRPC service.
33 changes: 33 additions & 0 deletions llms/repeater/generate_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions llms/repeater/generate_pb2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class CompletionRequest(_message.Message):
    # Type stub for the `CompletionRequest` protobuf message (generated from
    # generate.proto); it declares types only and carries no implementation.
    __slots__ = ["best_of", "echo", "frequence_penalty", "logit_bias", "logprobs", "max_tokens", "n", "presence_penalty", "prompt", "stop", "stream", "suffix", "temperature", "top_p"]
    class LogitBiasEntry(_message.Message):
        # Entry type backing the `logit_bias` map field: a str key and a float value.
        __slots__ = ["key", "value"]
        KEY_FIELD_NUMBER: _ClassVar[int]
        VALUE_FIELD_NUMBER: _ClassVar[int]
        key: str
        value: float
        def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ...
    # Class-level integer constants, one per message field.
    BEST_OF_FIELD_NUMBER: _ClassVar[int]
    ECHO_FIELD_NUMBER: _ClassVar[int]
    FREQUENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
    LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
    LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    MAX_TOKENS_FIELD_NUMBER: _ClassVar[int]
    N_FIELD_NUMBER: _ClassVar[int]
    PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
    PROMPT_FIELD_NUMBER: _ClassVar[int]
    STOP_FIELD_NUMBER: _ClassVar[int]
    STREAM_FIELD_NUMBER: _ClassVar[int]
    SUFFIX_FIELD_NUMBER: _ClassVar[int]
    TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
    TOP_P_FIELD_NUMBER: _ClassVar[int]
    # Instance attributes, listed alphabetically; names mirror the proto fields.
    best_of: int
    echo: bool
    frequence_penalty: float
    logit_bias: _containers.ScalarMap[str, float]
    logprobs: int
    max_tokens: int
    n: int
    presence_penalty: float
    prompt: str
    stop: str
    stream: bool
    suffix: str
    temperature: float
    top_p: float
    # Every field is an optional argument; `logit_bias` accepts any str -> float mapping.
    def __init__(self, prompt: _Optional[str] = ..., suffix: _Optional[str] = ..., max_tokens: _Optional[int] = ..., temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., n: _Optional[int] = ..., stream: bool = ..., logprobs: _Optional[int] = ..., echo: bool = ..., stop: _Optional[str] = ..., presence_penalty: _Optional[float] = ..., frequence_penalty: _Optional[float] = ..., best_of: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ...) -> None: ...

class CompletionResponse(_message.Message):
    # Type stub for the `CompletionResponse` protobuf message: a single
    # string field named `completion`.
    __slots__ = ["completion"]
    COMPLETION_FIELD_NUMBER: _ClassVar[int]
    completion: str
    def __init__(self, completion: _Optional[str] = ...) -> None: ...
66 changes: 66 additions & 0 deletions llms/repeater/generate_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import generate_pb2 as generate__pb2


class GenerateServiceStub(object):
    """Client-side stub for the ``leapfrog.GenerateService`` gRPC service."""

    def __init__(self, channel):
        """Bind the service's ``Complete`` RPC to a channel.

        Args:
            channel: A grpc.Channel.
        """
        request_type = generate__pb2.CompletionRequest
        response_type = generate__pb2.CompletionResponse
        # Unary-unary callable: one CompletionRequest in, one CompletionResponse out.
        self.Complete = channel.unary_unary(
            '/leapfrog.GenerateService/Complete',
            request_serializer=request_type.SerializeToString,
            response_deserializer=response_type.FromString,
        )


class GenerateServiceServicer(object):
    """Server-side base class for ``leapfrog.GenerateService`` handlers."""

    def Complete(self, request, context):
        """Default ``Complete`` handler: report UNIMPLEMENTED and raise.

        Args:
            request: The incoming request message (unused here).
            context: The RPC context; its status code and details are set.

        Raises:
            NotImplementedError: Always.
        """
        message = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(message)
        raise NotImplementedError(message)


def add_GenerateServiceServicer_to_server(servicer, server):
    """Register ``servicer`` on ``server`` as the ``leapfrog.GenerateService`` handler.

    Args:
        servicer: Object providing a ``Complete(request, context)`` method.
        server: The gRPC server to attach the generic handler to.
    """
    complete_handler = grpc.unary_unary_rpc_method_handler(
        servicer.Complete,
        request_deserializer=generate__pb2.CompletionRequest.FromString,
        response_serializer=generate__pb2.CompletionResponse.SerializeToString,
    )
    handlers_by_method = {'Complete': complete_handler}
    service_handler = grpc.method_handlers_generic_handler(
        'leapfrog.GenerateService', handlers_by_method)
    server.add_generic_rpc_handlers((service_handler,))


# This class is part of an EXPERIMENTAL API.
class GenerateService(object):
    """Channel-less convenience wrapper around the ``Complete`` RPC."""

    @staticmethod
    def Complete(request,
                 target,
                 options=(),
                 channel_credentials=None,
                 call_credentials=None,
                 insecure=False,
                 compression=None,
                 wait_for_ready=None,
                 timeout=None,
                 metadata=None):
        """Invoke ``/leapfrog.GenerateService/Complete`` against ``target``.

        All arguments are forwarded positionally to
        ``grpc.experimental.unary_unary``; note that ``insecure`` is passed
        before ``call_credentials`` there, unlike in this signature.
        """
        return grpc.experimental.unary_unary(
            request,
            target,
            '/leapfrog.GenerateService/Complete',
            generate__pb2.CompletionRequest.SerializeToString,
            generate__pb2.CompletionResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
        )
32 changes: 32 additions & 0 deletions llms/repeater/repeater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from concurrent import futures
import grpc
import generate_pb2
import generate_pb2_grpc


class Repeater(generate_pb2_grpc.GenerateServiceServicer):
    """Toy GenerateService implementation that echoes the prompt back.

    Subclasses the generated servicer so that any RPC not overridden here
    falls back to the generated default handler (which reports
    UNIMPLEMENTED) instead of being a missing attribute when the service
    is registered on the server.
    """

    def Complete(self, request, context):
        """Return a CompletionResponse whose completion is the request's prompt.

        Args:
            request: The incoming CompletionRequest; only ``prompt`` is read.
            context: The RPC context (unused).

        Returns:
            A ``generate_pb2.CompletionResponse`` with ``completion`` set to
            ``request.prompt``.
        """
        # Just returns what's provided.
        return generate_pb2.CompletionResponse(completion=request.prompt)


def serve():
    """Run the Repeater gRPC server on port 50051 until interrupted.

    Starts a server backed by a 10-worker thread pool, registers the
    Repeater servicer, listens on an insecure port on all interfaces, and
    then blocks. A KeyboardInterrupt (Ctrl-C) stops the server immediately.
    """
    # Create a gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    generate_pb2_grpc.add_GenerateServiceServicer_to_server(Repeater(), server)

    # Listen on port 50051
    print('Starting server. Listening on port 50051.')
    server.add_insecure_port('[::]:50051')
    server.start()

    # Block on the server itself rather than spinning in `while True: pass`:
    # a busy loop pins a CPU core and competes with the worker threads for
    # the interpreter lock.
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        server.stop(0)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    serve()
24 changes: 24 additions & 0 deletions llms/repeater/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import grpc
import generate_pb2
import generate_pb2_grpc

def run():
    """Send one CompletionRequest to the server on localhost:50051 and print the reply."""
    # Build the request up front; only the prompt is set here
    # (add other parameters as necessary).
    payload = generate_pb2.CompletionRequest(prompt="Hello, Chatbot!")

    # The channel is closed automatically when the `with` block exits.
    with grpc.insecure_channel('localhost:50051') as conn:
        client = generate_pb2_grpc.GenerateServiceStub(conn)

        # Make the call and show what came back.
        reply = client.Complete(payload)
        print("Received response: ", reply.completion)

# Run the client only when this file is executed as a script.
if __name__ == "__main__":
    run()
30 changes: 30 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
syntax = "proto3";

package leapfrog;

// CompletionRequest is the payload sent to request a completion.
message CompletionRequest {
  string prompt = 1;
  string suffix = 2;
  int32 max_tokens = 3;
  float temperature = 4;
  float top_p = 5;
  int32 n = 6;
  bool stream = 7;
  int32 logprobs = 8;
  bool echo = 9;
  string stop = 10; // A Union[str, list] value can only be represented here as a single string.
  float presence_penalty = 11;
  // NOTE(review): `frequence_penalty` looks like a misspelling of
  // `frequency_penalty`; renaming it changes the generated field name, so
  // confirm with callers before fixing.
  float frequence_penalty = 12;
  int32 best_of = 13;
  map<string, float> logit_bias = 14; // Map from a string key to a float value.
}

// CompletionResponse is what the gRPC service returns.
message CompletionResponse {
  string completion = 1;
}

// GenerateService exposes completion over gRPC.
service GenerateService {
  // Complete takes a single CompletionRequest and returns a single CompletionResponse.
  rpc Complete (CompletionRequest) returns (CompletionResponse);
}