Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(openai): dynamic model_type registration #704

Merged
merged 1 commit into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions openllm-python/src/openllm/entrypoints/_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
example:
object: 'list'
data:
- id: meta-llama--Llama-2-13b-chat-hf
- id: __model_id__
object: model
created: 1686935002
owned_by: 'na'
Expand Down Expand Up @@ -69,7 +69,7 @@
content: You are a helpful assistant.
- role: user
content: Hello, I'm looking for a chatbot that can help me with my work.
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
Expand All @@ -83,7 +83,7 @@
content: You are a helpful assistant.
- role: user
content: Hello, I'm looking for a chatbot that can help me with my work.
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
Expand Down Expand Up @@ -206,7 +206,7 @@
summary: One-shot input example
value:
prompt: This is a test
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
logprobs: 1
Expand All @@ -217,7 +217,7 @@
summary: Streaming input example
value:
prompt: This is a test
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
Expand Down Expand Up @@ -472,6 +472,12 @@
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}


def apply_schema(func, **attrs):
  """Substitute placeholder tokens inside ``func``'s docstring, in place.

  Each keyword argument name is treated as a literal placeholder string
  (e.g. ``__model_id__``) and every occurrence of it in ``func.__doc__``
  is replaced with the corresponding value. Used at mount time to inject
  the concrete model id into the OpenAPI example docstrings.

  Args:
    func: Function whose ``__doc__`` is rewritten in place.
    **attrs: Mapping of placeholder text to its replacement value.

  Returns:
    The same ``func`` object, so the call composes inline
    (e.g. ``functools.partial(apply_schema(fn, ...), llm=llm)``).
  """
  # Guard against a missing docstring: functions defined without one, or
  # any function when running under ``python -OO``, have ``__doc__ is None``
  # and ``None.replace`` would raise AttributeError.
  if func.__doc__ is not None:
    for placeholder, value in attrs.items():
      func.__doc__ = func.__doc__.replace(placeholder, value)
  # NOTE(review): replacement is destructive — once a placeholder has been
  # substituted it is gone from the docstring, so a second call with a
  # different value is a no-op. Fine for a single mount per process;
  # confirm if the same endpoint is ever re-mounted with another model.
  return func


def add_schema_definitions(func):
append_str = _SCHEMAS.get(func.__name__.lower(), '')
if not append_str:
Expand Down
1 change: 1 addition & 0 deletions openllm-python/src/openllm/entrypoints/_openapi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class OpenLLMSchemaGenerator:
def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]: ...
def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]: ...

def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ...
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
def append_schemas(
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
Expand Down
21 changes: 17 additions & 4 deletions openllm-python/src/openllm/entrypoints/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from openllm_core._schemas import SampleLogprobs
from openllm_core.utils import converter, gen_random_uuid

from ._openapi import add_schema_definitions, append_schemas, get_generator
from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
from ..protocol.openai import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand Down Expand Up @@ -100,12 +100,25 @@ def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):


def mount_to_svc(svc, llm):
list_models.__doc__ = list_models.__doc__.replace('__model_id__', llm.llm_type)
completions.__doc__ = completions.__doc__.replace('__model_id__', llm.llm_type)
chat_completions.__doc__ = chat_completions.__doc__.replace('__model_id__', llm.llm_type)
app = Starlette(
debug=True,
routes=[
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
Route(
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
),
Route(
'/completions',
functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm),
methods=['POST'],
),
Route(
'/chat/completions',
functools.partial(apply_schema(chat_completions, __model_id__=llm.llm_type), llm=llm),
methods=['POST'],
),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)
Expand Down