WIP register_models() plugin hook, refs #53
Showing 6 changed files with 245 additions and 42 deletions.
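This commit routes model registration through a plugin hook. The files that define the hook specification and dispatch side are part of the commit but are not excerpted below, so the following is only a sketch of what they plausibly look like, based on standard pluggy conventions (the project's plugin system); the helper name get_models_with_aliases() and the tuple collection shape are assumptions, not the commit's actual code.

# Sketch only: plausible hookspec/dispatch wiring for register_models(),
# following pluggy conventions. Not taken from the excerpted files.
import sys
import pluggy

hookspec = pluggy.HookspecMarker("llm")
hookimpl = pluggy.HookimplMarker("llm")


@hookspec
def register_models(register):
    "Plugins call register(model, aliases=...) for each model they provide"


pm = pluggy.PluginManager("llm")
pm.add_hookspecs(sys.modules[__name__])


def get_models_with_aliases():
    # Hypothetical helper name; collects everything plugins register
    model_aliases = []

    def register(model, aliases=None):
        # The commit's ModelWithAliases dataclass is the likely container;
        # plain tuples keep this sketch self-contained
        model_aliases.append((model, set(aliases or ())))

    pm.hook.register_models(register=register)
    return model_aliases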
@@ -0,0 +1,80 @@
from dataclasses import dataclass
from typing import Any, Dict, Generator, Optional, Set
from abc import ABC, abstractmethod
from pydantic import BaseModel


@dataclass
class Prompt:
    prompt: str
    model: "Model"
    system: Optional[str]
    prompt_json: Optional[str]
    options: Dict[str, Any]

    # Explicit __init__ (overriding the dataclass-generated one) so that
    # options defaults to an empty dict
    def __init__(self, prompt, model, system=None, prompt_json=None, options=None):
        self.prompt = prompt
        self.model = model
        self.system = system
        self.prompt_json = prompt_json
        self.options = options or {}


class OptionsError(Exception):
    pass


class Response(ABC):
    def __init__(self, prompt: Prompt):
        self.prompt = prompt
        self._chunks = []
        self._debug = {}
        self._done = False

    def __iter__(self):
        if self._done:
            # Replay cached chunks instead of executing the prompt again
            yield from self._chunks
            return
        for chunk in self.iter_prompt():
            yield chunk
            self._chunks.append(chunk)
        self._done = True

    @abstractmethod
    def iter_prompt(self) -> Generator[str, None, None]:
        "Execute the prompt, yielding chunks of the response as they arrive"

    def _force(self):
        # Exhaust the iterator so every chunk has been cached
        if not self._done:
            list(self)

    def text(self):
        self._force()
        return "".join(self._chunks)


class Model(ABC):
    model_id: str

    class Options(BaseModel):
        class Config:
            # Reject unknown option names rather than silently ignoring them
            extra = "forbid"

    def prompt(self, prompt, system=None, stream=True, **options):
        return self.execute(
            Prompt(prompt, system=system, model=self, options=self.Options(**options)),
            stream=stream,
        )

    @abstractmethod
    def execute(self, prompt: Prompt, stream: bool = True) -> Response:
        pass

    @abstractmethod
    def __str__(self) -> str:
        pass


@dataclass
class ModelWithAliases:
    model: Model
    aliases: Set[str]
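The abstract base classes above define the full contract a model implements: execute() returns a Response, and iter_prompt() does the actual work. As a quick illustration of how the pieces compose, here is a minimal sketch using a hypothetical EchoModel (not part of the commit), assuming the classes above are in scope:

# Hypothetical example, not from the commit: a model that echoes its prompt
class EchoResponse(Response):
    def iter_prompt(self):
        # Yield the prompt back one word at a time to simulate streaming
        for word in self.prompt.prompt.split():
            yield word + " "


class EchoModel(Model):
    model_id = "echo"

    def execute(self, prompt: Prompt, stream: bool = True) -> Response:
        return EchoResponse(prompt)

    def __str__(self) -> str:
        return "Echo: {}".format(self.model_id)


model = EchoModel()
response = model.prompt("hello world")
for chunk in response:  # first pass runs iter_prompt() and caches chunks
    print(chunk, end="")
print(response.text())  # replays the cache; the prompt is not re-executed

Note how the caching in Response.__iter__ means .text() and repeated iteration never trigger a second execution.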
@@ -0,0 +1,58 @@
from . import Model, Prompt, OptionsError, Response, hookimpl
from typing import Optional
import openai


@hookimpl
def register_models(register):
    register(Chat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt"))
    register(Chat("gpt-3.5-turbo-16k"), aliases=("chatgpt-16k", "3.5-16k"))
    register(Chat("gpt-4"), aliases=("4", "gpt4"))
    register(Chat("gpt-4-32k"), aliases=("4-32k",))


class ChatResponse(Response):
    def __init__(self, prompt, stream):
        self.prompt = prompt
        self.stream = stream
        super().__init__(prompt)

    def iter_prompt(self):
        messages = []
        if self.prompt.system:
            messages.append({"role": "system", "content": self.prompt.system})
        messages.append({"role": "user", "content": self.prompt.prompt})
        if self.stream:
            # Streaming: yield each delta's content as it arrives
            for chunk in openai.ChatCompletion.create(
                model=self.prompt.model.model_id,
                messages=messages,
                stream=True,
            ):
                self._debug["model"] = chunk.model
                content = chunk["choices"][0].get("delta", {}).get("content")
                if content is not None:
                    yield content
            self._done = True
        else:
            # Non-streaming: a single response, which also carries usage data
            response = openai.ChatCompletion.create(
                model=self.prompt.model.model_id,
                messages=messages,
                stream=False,
            )
            self._debug["model"] = response.model
            self._debug["usage"] = response.usage
            content = response.choices[0].message.content
            self._done = True
            yield content


class Chat(Model):
    def __init__(self, model_id, stream=True):
        self.model_id = model_id
        self.stream = stream

    def execute(self, prompt: Prompt, stream: bool = True) -> ChatResponse:
        return ChatResponse(prompt, stream)

    def __str__(self):
        return "OpenAI Chat: {}".format(self.model_id)