
Integrate Yi Model #343

Merged: 10 commits, Aug 28, 2024
37 changes: 19 additions & 18 deletions README.md

Large diffs are not rendered by default.

37 changes: 19 additions & 18 deletions README_ZH.md

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions docs/sphinx_doc/en/source/tutorial/203-model.md
@@ -74,7 +74,7 @@ In the current AgentScope, the supported `model_type` types, the corresponding

| API | Task | Model Wrapper | `model_type` | Some Supported Models |
|------------------------|-----------------|---------------------------------------------------------------------------------------------------------------------------------|-------------------------------|--------------------------------------------------|
| OpenAI API | Chat | [`OpenAIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py) | `"openai_chat"` | gpt-4, gpt-3.5-turbo, ... |
| | Embedding | [`OpenAIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py) | `"openai_embedding"` | text-embedding-ada-002, ... |
| | DALL·E | [`OpenAIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py) | `"openai_dall_e"` | dall-e-2, dall-e-3 |
| DashScope API | Chat | [`DashScopeChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/dashscope_model.py) | `"dashscope_chat"` | qwen-plus, qwen-max, ... |
@@ -83,12 +83,13 @@ In the current AgentScope, the supported `model_type` types, the corresponding
| | Multimodal | [`DashScopeMultiModalWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/dashscope_model.py) | `"dashscope_multimodal"` | qwen-vl-plus, qwen-vl-max, qwen-audio-turbo, ... |
| Gemini API | Chat | [`GeminiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/gemini_model.py) | `"gemini_chat"` | gemini-pro, ... |
| | Embedding | [`GeminiEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/gemini_model.py) | `"gemini_embedding"` | models/embedding-001, ... |
| ZhipuAI API | Chat | [`ZhipuAIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/zhipu_model.py) | `"zhipuai_chat"` | glm4, ... |
| | Embedding | [`ZhipuAIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/zhipu_model.py) | `"zhipuai_embedding"` | embedding-2, ... |
| ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_chat"` | llama2, ... |
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
| | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_dall_e"` | - |
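
The new `yi_chat` row means a model config whose `model_type` is `"yi_chat"` is dispatched to `YiChatWrapper`. Below is a minimal sketch of registering such a config; the `config_name` and API key are placeholders.

```python
import agentscope

# A hypothetical dict-based registration: "model_type": "yi_chat"
# routes this config to YiChatWrapper per the table above.
agentscope.init(
    model_configs=[
        {
            "config_name": "my_yi_chat",
            "model_type": "yi_chat",
            "model_name": "yi-large",
            "api_key": "{your_api_key}",
        },
    ],
)
```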
1 change: 1 addition & 0 deletions docs/sphinx_doc/zh_CN/source/tutorial/203-model.md
@@ -109,6 +109,7 @@ The APIs are as follows:
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
| | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_dall_e"` | - |
11 changes: 11 additions & 0 deletions examples/model_configs_template/yi_chat_template.json
@@ -0,0 +1,11 @@
[
    {
        "config_name": "yi_yi-large",
        "model_type": "yi_chat",
        "model_name": "yi-large",
        "api_key": "{your_api_key}",
        "temperature": 0.3,
        "top_p": 0.9,
        "max_tokens": 1000
    }
]
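
A sketch of one way this template might be wired up once saved locally; the file path, agent name, and prompt are illustrative, and `DialogAgent` picks up the registered config via its `config_name`.

```python
import agentscope
from agentscope.agents import DialogAgent
from agentscope.message import Msg

# Load the model configs from the template above (path is illustrative).
agentscope.init(
    model_configs="examples/model_configs_template/yi_chat_template.json",
)

# "yi_yi-large" matches the "config_name" field in the template.
agent = DialogAgent(
    name="assistant",
    sys_prompt="You are a helpful assistant.",
    model_config_name="yi_yi-large",
)

reply = agent(Msg("user", "Introduce yourself in one sentence.", role="user"))
print(reply.content)
```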
5 changes: 4 additions & 1 deletion src/agentscope/models/__init__.py
@@ -38,7 +38,9 @@
from .litellm_model import (
    LiteLLMChatWrapper,
)
+from .yi_model import (
+    YiChatWrapper,
+)

__all__ = [
    "ModelWrapperBase",
@@ -61,6 +63,7 @@
"ZhipuAIChatWrapper",
"ZhipuAIEmbeddingWrapper",
"LiteLLMChatWrapper",
"YiChatWrapper",
]


4 changes: 2 additions & 2 deletions src/agentscope/models/openai_model.py
@@ -188,7 +188,7 @@ def __init__(

    def __call__(
        self,
-       messages: list,
+       messages: list[dict],
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ModelResponse:
@@ -331,7 +331,7 @@ def _save_model_invocation_and_update_monitor(
            response=response,
        )

-       usage = response.get("usage")
+       usage = response.get("usage", None)
        if usage is not None:
            self.monitor.update_text_and_embedding_tokens(
                model_name=self.model_name,
292 changes: 292 additions & 0 deletions src/agentscope/models/yi_model.py
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
"""Model wrapper for Yi models"""
import json
from typing import (
    List,
    Union,
    Sequence,
    Optional,
    Generator,
)

import requests

from ._model_utils import (
    _verify_text_content_in_openai_message_response,
    _verify_text_content_in_openai_delta_response,
)
from .model import ModelWrapperBase, ModelResponse
from ..message import Msg


class YiChatWrapper(ModelWrapperBase):
    """The model wrapper for the Yi Chat API.

    Response:
        - From https://platform.lingyiwanwu.com/docs

        ```json
        {
            "id": "cmpl-ea89ae83",
            "object": "chat.completion",
            "created": 5785971,
            "model": "yi-large-rag",
            "usage": {
                "completion_tokens": 113,
                "prompt_tokens": 896,
                "total_tokens": 1009
            },
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Today in Los Angeles, the weather ...",
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        ```
    """

    model_type: str = "yi_chat"

    def __init__(
        self,
        config_name: str,
        model_name: str,
        api_key: str,
        max_tokens: Optional[int] = None,
        top_p: float = 0.9,
        temperature: float = 0.3,
        stream: bool = False,
    ) -> None:
        """Initialize the Yi chat model wrapper.

        Args:
            config_name (`str`):
                The name of the configuration to use.
            model_name (`str`):
                The name of the model to use, e.g. yi-large, yi-medium, etc.
            api_key (`str`):
                The API key for the Yi API.
            max_tokens (`Optional[int]`, defaults to `None`):
                The maximum number of tokens to generate.
            top_p (`float`, defaults to `0.9`):
                The randomness parameter in the range [0, 1].
            temperature (`float`, defaults to `0.3`):
                The temperature parameter in the range [0, 2].
            stream (`bool`, defaults to `False`):
                Whether to stream the response or not.
        """

        super().__init__(config_name, model_name)

        if top_p > 1 or top_p < 0:
            raise ValueError(
                f"The `top_p` parameter must be in the range [0, 1], but got "
                f"{top_p} instead.",
            )

        if temperature < 0 or temperature > 2:
            raise ValueError(
                f"The `temperature` parameter must be in the range [0, 2], "
                f"but got {temperature} instead.",
            )

        self.api_key = api_key
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.stream = stream

    def __call__(
        self,
        messages: list[dict],
        stream: Optional[bool] = None,
    ) -> ModelResponse:
        """Invoke the Yi Chat API by sending a list of messages."""

        # Checking messages
        if not isinstance(messages, list):
            raise ValueError(
                f"Yi `messages` field expected type `list`, "
                f"got `{type(messages)}` instead.",
            )

        if not all("role" in msg and "content" in msg for msg in messages):
            raise ValueError(
                "Each message in the 'messages' list must contain a 'role' "
                "and 'content' key for Yi API.",
            )

        if stream is None:
            stream = self.stream

        # Forward to generate response
        kwargs = {
            "url": "https://api.lingyiwanwu.com/v1/chat/completions",
            "json": {
                "model": self.model_name,
                "messages": messages,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "stream": stream,
            },
            "headers": {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        }

        response = requests.post(**kwargs)
        response.raise_for_status()

        if stream:

            def generator() -> Generator[str, None, None]:
                text = ""
                last_chunk = {}
                for line in response.iter_lines():
                    if line:
                        line_str = line.decode("utf-8").strip()

                        # Remove the prefix "data: " if it exists
                        json_str = line_str.removeprefix("data: ")

                        # The last response is "data: [DONE]"
                        if json_str == "[DONE]":
                            continue

                        try:
                            chunk = json.loads(json_str)
                            if _verify_text_content_in_openai_delta_response(
                                chunk,
                            ):
                                text += chunk["choices"][0]["delta"]["content"]
                                yield text
                            last_chunk = chunk

                        except json.decoder.JSONDecodeError as e:
                            raise json.decoder.JSONDecodeError(
                                f"Invalid JSON: {json_str}",
                                e.doc,
                                e.pos,
                            ) from e

                # In the Yi Chat API, the last valid chunk saves all the text
                # of this message
                self._save_model_invocation_and_update_monitor(
                    kwargs,
                    last_chunk,
                )

            return ModelResponse(
                stream=generator(),
            )
        else:
            response = response.json()
            self._save_model_invocation_and_update_monitor(
                kwargs,
                response,
            )

            # Re-use the OpenAI response checking function
            if _verify_text_content_in_openai_message_response(response):
                return ModelResponse(
                    text=response["choices"][0]["message"]["content"],
                    raw=response,
                )
            else:
                raise RuntimeError(
                    f"Invalid response from Yi Chat API: {response}",
                )

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> List[dict]:
        """Format the messages into the required format of the Yi Chat API.

        Note that this strategy may not be suitable for all scenarios,
        and developers are encouraged to implement their own prompt
        engineering strategies.

        The following is an example:

        .. code-block:: python

            prompt1 = model.format(
                Msg("system", "You're a helpful assistant", role="system"),
                Msg("Bob", "Hi, how can I help you?", role="assistant"),
                Msg("user", "What's the date today?", role="user")
            )

        The prompt will be as follows:

        .. code-block:: python

            # prompt1
            [
                {
                    "role": "user",
                    "content": (
                        "You're a helpful assistant\\n"
                        "\\n"
                        "## Conversation History\\n"
                        "Bob: Hi, how can I help you?\\n"
                        "user: What's the date today?"
                    )
                }
            ]

        Args:
            args (`Union[Msg, Sequence[Msg]]`):
                The input arguments to be formatted, where each argument
                should be a `Msg` object, or a list of `Msg` objects.
                In distribution, placeholder is also allowed.

        Returns:
            `List[dict]`:
                The formatted messages.
        """

        # TODO: Support Vision model
        if self.model_name == "yi-vision":
            raise NotImplementedError(
                "Yi Vision model is not supported in the current version, "
                "please format the messages manually.",
            )

        return ModelWrapperBase.format_for_common_chat_models(*args)

    def _save_model_invocation_and_update_monitor(
        self,
        kwargs: dict,
        response: dict,
    ) -> None:
        """Save the model invocation and update the monitor accordingly.

        Args:
            kwargs (`dict`):
                The keyword arguments used in the model invocation.
            response (`dict`):
                The response from the model API.
        """
        self._save_model_invocation(
            arguments=kwargs,
            response=response,
        )

        usage = response.get("usage", None)
        if usage is not None:
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)

            self.monitor.update_text_and_embedding_tokens(
                model_name=self.model_name,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
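
For reference, a minimal sketch of driving the new wrapper directly, outside AgentScope's agent machinery; the `config_name`, model name, and API key below are placeholders.

```python
from agentscope.message import Msg
from agentscope.models.yi_model import YiChatWrapper

# Hypothetical direct invocation; requires a valid Yi API key.
wrapper = YiChatWrapper(
    config_name="yi_demo",
    model_name="yi-large",
    api_key="{your_api_key}",
)

# format() merges the Msg objects into the message list the API expects.
messages = wrapper.format(
    Msg("system", "You're a helpful assistant", role="system"),
    Msg("user", "What's the date today?", role="user"),
)

response = wrapper(messages)
print(response.text)
```

When `stream=True` is passed instead, `__call__` returns a `ModelResponse` wrapping the generator defined above, which yields the accumulated text chunk by chunk.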