Skip to content

Commit

Permalink
Merge pull request #92 from defenseunicorns/protobuf
Browse files Browse the repository at this point in the history
first proto example
  • Loading branch information
gerred authored Jun 14, 2023
2 parents 9570cd5 + 5b4a763 commit 0a9b7a2
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 0 deletions.
21 changes: 21 additions & 0 deletions docs/Protobuf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Create new Model


## Generate a stubbed out Python server

```shell
pip install grpcio-tools
```


To create the generated files, run:
```shell
python -m grpc_tools.protoc --proto_path=proto generate.proto --python_out=llms/repeater --pyi_out=llms/repeater --grpc_python_out=llms/repeater
```

which creates the following files
* generate_pb2_grpc.py
* generate_pb2.py
* generate_pb2.pyi

These generated modules are then used in `repeater.py` to implement the gRPC service.
33 changes: 33 additions & 0 deletions llms/repeater/generate_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions llms/repeater/generate_pb2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class CompletionRequest(_message.Message):
__slots__ = ["best_of", "echo", "frequence_penalty", "logit_bias", "logprobs", "max_tokens", "n", "presence_penalty", "prompt", "stop", "stream", "suffix", "temperature", "top_p"]
class LogitBiasEntry(_message.Message):
__slots__ = ["key", "value"]
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: float
def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ...
BEST_OF_FIELD_NUMBER: _ClassVar[int]
ECHO_FIELD_NUMBER: _ClassVar[int]
FREQUENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
MAX_TOKENS_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
PROMPT_FIELD_NUMBER: _ClassVar[int]
STOP_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
SUFFIX_FIELD_NUMBER: _ClassVar[int]
TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
TOP_P_FIELD_NUMBER: _ClassVar[int]
best_of: int
echo: bool
frequence_penalty: float
logit_bias: _containers.ScalarMap[str, float]
logprobs: int
max_tokens: int
n: int
presence_penalty: float
prompt: str
stop: str
stream: bool
suffix: str
temperature: float
top_p: float
def __init__(self, prompt: _Optional[str] = ..., suffix: _Optional[str] = ..., max_tokens: _Optional[int] = ..., temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., n: _Optional[int] = ..., stream: bool = ..., logprobs: _Optional[int] = ..., echo: bool = ..., stop: _Optional[str] = ..., presence_penalty: _Optional[float] = ..., frequence_penalty: _Optional[float] = ..., best_of: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ...) -> None: ...

class CompletionResponse(_message.Message):
__slots__ = ["completion"]
COMPLETION_FIELD_NUMBER: _ClassVar[int]
completion: str
def __init__(self, completion: _Optional[str] = ...) -> None: ...
66 changes: 66 additions & 0 deletions llms/repeater/generate_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import generate_pb2 as generate__pb2


class GenerateServiceStub(object):
"""Missing associated documentation comment in .proto file."""

def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Complete = channel.unary_unary(
'/leapfrog.GenerateService/Complete',
request_serializer=generate__pb2.CompletionRequest.SerializeToString,
response_deserializer=generate__pb2.CompletionResponse.FromString,
)


class GenerateServiceServicer(object):
"""Missing associated documentation comment in .proto file."""

def Complete(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_GenerateServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'Complete': grpc.unary_unary_rpc_method_handler(
servicer.Complete,
request_deserializer=generate__pb2.CompletionRequest.FromString,
response_serializer=generate__pb2.CompletionResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'leapfrog.GenerateService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class GenerateService(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def Complete(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/leapfrog.GenerateService/Complete',
generate__pb2.CompletionRequest.SerializeToString,
generate__pb2.CompletionResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
32 changes: 32 additions & 0 deletions llms/repeater/repeater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from concurrent import futures
import grpc
import generate_pb2
import generate_pb2_grpc


class Repeater(object):
def Complete(self, request, context):

result = request.prompt # just returns what's provided
return generate_pb2.CompletionResponse(completion=result)


def serve():
# Create a gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
generate_pb2_grpc.add_GenerateServiceServicer_to_server(Repeater(), server)

# Listen on port 50051
print('Starting server. Listening on port 50051.')
server.add_insecure_port('[::]:50051')
server.start()

# Keep thread alive
try:
while True:
pass
except KeyboardInterrupt:
server.stop(0)

if __name__ == "__main__":
serve()
24 changes: 24 additions & 0 deletions llms/repeater/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import grpc
import generate_pb2
import generate_pb2_grpc

def run():
# Set up a channel to the server
with grpc.insecure_channel('localhost:50051') as channel:
# Instantiate a stub (client)
stub = generate_pb2_grpc.GenerateServiceStub(channel)

# Create a request
request = generate_pb2.CompletionRequest(
prompt="Hello, Chatbot!",
# add other parameters as necessary
)

# Make a call to the server and get a response
response = stub.Complete(request)

# Print the response
print("Received response: ", response.completion)

if __name__ == "__main__":
run()
30 changes: 30 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
syntax = "proto3";

package leapfrog;

// CompletionRequest is the payload to request completion
message CompletionRequest {
string prompt = 1;
string suffix = 2;
int32 max_tokens = 3;
float temperature = 4;
float top_p = 5;
int32 n = 6;
bool stream = 7;
int32 logprobs = 8;
bool echo = 9;
string stop = 10; // You can only represent Union[str, list] as a string.
float presence_penalty = 11;
float frequence_penalty = 12;
int32 best_of = 13;
map<string, float> logit_bias = 14; // Maps are represented as a pair of a key type and a value type.
}

// CompletionRespones are what's returned by the gRPC service
message CompletionResponse {
string completion = 1;
}

service GenerateService {
rpc Complete (CompletionRequest) returns (CompletionResponse);
}

0 comments on commit 0a9b7a2

Please sign in to comment.