diff --git a/requirements.txt b/requirements.txt index ece754e..c6fd132 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -steamship~=2.17.0 -langchain==0.0.168 \ No newline at end of file +steamship~=2.17.4 +langchain==0.0.200 \ No newline at end of file diff --git a/src/steamship_langchain/chat_models/__init__.py b/src/steamship_langchain/chat_models/__init__.py index e69de29..fccc818 100644 --- a/src/steamship_langchain/chat_models/__init__.py +++ b/src/steamship_langchain/chat_models/__init__.py @@ -0,0 +1,3 @@ +from steamship_langchain.chat_models.openai import ChatOpenAI + +__all__ = ["ChatOpenAI"] diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py index 50b24b3..93f1808 100644 --- a/src/steamship_langchain/chat_models/openai.py +++ b/src/steamship_langchain/chat_models/openai.py @@ -1,34 +1,44 @@ """OpenAI chat wrapper.""" from __future__ import annotations +import json import logging from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple import tiktoken +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import BaseChatModel +from langchain.chat_models.openai import ChatOpenAI from langchain.schema import ( AIMessage, BaseMessage, ChatGeneration, ChatMessage, ChatResult, + FunctionMessage, HumanMessage, LLMResult, SystemMessage, ) -from pydantic import Extra, Field, root_validator +from pydantic import Extra, Field, ValidationError from steamship import Block, File, MimeTypes, PluginInstance, Steamship, Tag -from steamship.data.tags.tag_constants import RoleTag, TagKind +from steamship.data.tags.tag_constants import TagKind logger = logging.getLogger(__file__) -def _convert_dict_to_message(_dict: dict) -> BaseMessage: +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: role = _dict["role"] if role == "user": return HumanMessage(content=_dict["content"]) elif role == "assistant": - return AIMessage(content=_dict["content"]) + content = _dict["content"] + if "function_call" in content: + try: + return AIMessage(content="", additional_kwargs=json.loads(content)) + except Exception: + pass + return AIMessage(content=content) elif role == "system": return SystemMessage(content=_dict["content"]) else: @@ -42,8 +52,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } else: raise ValueError(f"Got unknown type {message}") if "name" in message.additional_kwargs: @@ -51,7 +69,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: return message_dict -class ChatOpenAI(BaseChatModel): +class ChatOpenAI(ChatOpenAI, BaseChatModel): """Wrapper around OpenAI Chat large language models. To use, you should have the ``openai`` python package installed, and the @@ -68,7 +86,7 @@ class ChatOpenAI(BaseChatModel): """ client: Any #: :meta private: - model_name: str = "gpt-3.5-turbo" + model_name: str = "gpt-3.5-turbo-0613" """Model name to use.""" temperature: float = 0.7 """What sampling temperature to use.""" @@ -95,11 +113,30 @@ class Config: def __init__( self, client: Steamship, - model_name: str = "gpt-3.5-turbo", + model_name: str = "gpt-3.5-turbo-0613", moderate_output: bool = True, **kwargs, ): - super().__init__(client=client, model_name=model_name, **kwargs) + try: + + class OpenAI(object): + class ChatCompletion: + pass + + import sys + + sys.modules["openai"] = OpenAI + + dummy_api_key = False + if "openai_api_key" not in kwargs: + kwargs["openai_api_key"] = "DUMMY" + dummy_api_key = True + super().__init__(client=client, model_name=model_name, **kwargs) + if dummy_api_key: + self.openai_api_key = None + except ValidationError as e: + print(e) + self.client = client plugin_config = {"model": self.model_name, "moderate_output": moderate_output} if self.openai_api_key: plugin_config["openai_api_key"] = self.openai_api_key @@ -122,14 +159,6 @@ def __init__( fetch_if_exists=True, ) - @classmethod - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - if values["n"] < 1: - raise ValueError("n must be at least 1.") - return values - @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" @@ -154,12 +183,16 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]: for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") + name = msg.get("name", "") if len(content) > 0: - role_tag = RoleTag(role) + tags = [Tag(kind=TagKind.ROLE, name=role)] + if name: + tags.append(Tag(kind="name", name=name)) + blocks.append( Block( text=content, - tags=[Tag(kind=TagKind.ROLE, name=role_tag)], + tags=tags, mime_type=MimeTypes.TXT, ) ) @@ -169,14 +202,24 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]: generate_task.wait() return [ - _convert_dict_to_message({"content": block.text, "role": RoleTag.USER.value}) + _convert_dict_to_message( + { + "content": block.text, + "role": [tag for tag in block.tags if tag.kind == TagKind.ROLE.value][0].name, + } + ) for block in generate_task.output.blocks ] def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} messages = self._complete(messages=message_dicts, **params) return ChatResult( generations=[ChatGeneration(message=message) for message in messages],