Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ec/support functions #53

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
steamship~=2.17.0
langchain==0.0.168
steamship~=2.17.4
langchain==0.0.200
3 changes: 3 additions & 0 deletions src/steamship_langchain/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from steamship_langchain.chat_models.openai import ChatOpenAI

__all__ = ["ChatOpenAI"]
83 changes: 63 additions & 20 deletions src/steamship_langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,44 @@
"""OpenAI chat wrapper."""
from __future__ import annotations

import json
import logging
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple

import tiktoken
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatMessage,
ChatResult,
FunctionMessage,
HumanMessage,
LLMResult,
SystemMessage,
)
from pydantic import Extra, Field, root_validator
from pydantic import Extra, Field, ValidationError
from steamship import Block, File, MimeTypes, PluginInstance, Steamship, Tag
from steamship.data.tags.tag_constants import RoleTag, TagKind
from steamship.data.tags.tag_constants import TagKind

logger = logging.getLogger(__file__)


def _convert_dict_to_message(_dict: dict) -> BaseMessage:
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
content = _dict["content"]
if "function_call" in content:
try:
return AIMessage(content="", additional_kwargs=json.loads(content))
except Exception:
pass
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
Expand All @@ -42,16 +52,24 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict


class ChatOpenAI(BaseChatModel):
class ChatOpenAI(ChatOpenAI, BaseChatModel):
"""Wrapper around OpenAI Chat large language models.

To use, you should have the ``openai`` python package installed, and the
Expand All @@ -68,7 +86,7 @@ class ChatOpenAI(BaseChatModel):
"""

client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo"
model_name: str = "gpt-3.5-turbo-0613"
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
Expand All @@ -95,11 +113,30 @@ class Config:
def __init__(
self,
client: Steamship,
model_name: str = "gpt-3.5-turbo",
model_name: str = "gpt-3.5-turbo-0613",
moderate_output: bool = True,
**kwargs,
):
super().__init__(client=client, model_name=model_name, **kwargs)
try:

class OpenAI(object):
class ChatCompletion:
pass

import sys

sys.modules["openai"] = OpenAI

dummy_api_key = False
if "openai_api_key" not in kwargs:
kwargs["openai_api_key"] = "DUMMY"
dummy_api_key = True
super().__init__(client=client, model_name=model_name, **kwargs)
if dummy_api_key:
self.openai_api_key = None
except ValidationError as e:
print(e)
self.client = client
plugin_config = {"model": self.model_name, "moderate_output": moderate_output}
if self.openai_api_key:
plugin_config["openai_api_key"] = self.openai_api_key
Expand All @@ -122,14 +159,6 @@ def __init__(
fetch_if_exists=True,
)

@classmethod
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
Expand All @@ -154,12 +183,16 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
name = msg.get("name", "")
if len(content) > 0:
role_tag = RoleTag(role)
tags = [Tag(kind=TagKind.ROLE, name=role)]
if name:
tags.append(Tag(kind="name", name=name))

blocks.append(
Block(
text=content,
tags=[Tag(kind=TagKind.ROLE, name=role_tag)],
tags=tags,
mime_type=MimeTypes.TXT,
)
)
Expand All @@ -169,14 +202,24 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
generate_task.wait()

return [
_convert_dict_to_message({"content": block.text, "role": RoleTag.USER.value})
_convert_dict_to_message(
{
"content": block.text,
"role": [tag for tag in block.tags if tag.kind == TagKind.ROLE.value][0].name,
}
)
for block in generate_task.output.blocks
]

def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
messages = self._complete(messages=message_dicts, **params)
return ChatResult(
generations=[ChatGeneration(message=message) for message in messages],
Expand Down
Loading