Unify model implementation and clean up their code in AgentScope. #82

Merged · 8 commits · Mar 20, 2024
2 changes: 1 addition & 1 deletion docs/sphinx_doc/en/source/tutorial/203-model.md
@@ -60,7 +60,7 @@ which we will use to specify the model service when initializing an agent.
It corresponds to the `model_type` field in the `ModelWrapper` class in the source code.

```python
class OpenAIChatWrapper(OpenAIWrapper):
class OpenAIChatWrapper(OpenAIWrapperBase):
"""The model wrapper for OpenAI's chat API."""

model_type: str = "openai"
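The excerpt above ties the `model_type` class attribute to the `model_type` field in a model configuration. For context, a minimal sketch of a config that would select this wrapper; only `model_type` is taken from the diff, while the remaining field names are assumptions based on the surrounding tutorial:

```python
# Hypothetical model config; "model_type" matches
# OpenAIChatWrapper.model_type ("openai") shown in the diff above.
openai_config = {
    "config_name": "my_openai_chat",  # name referenced when creating an agent
    "model_type": "openai",           # routes the config to OpenAIChatWrapper
    "model_name": "gpt-4",
    "api_key": "xxx",                 # placeholder
}
```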
2 changes: 1 addition & 1 deletion docs/sphinx_doc/en/source/tutorial/206-prompt.md
@@ -26,7 +26,7 @@ The `PromptEngine` class provides a structured way to combine different components
When creating an instance of `PromptEngine`, you can specify the target model and, optionally, the shrinking policy, the maximum length of the prompt, the prompt type, and a summarization model (could be the same as the target model).

```python
model = OpenAIWrapper(...)
model = OpenAIChatWrapper(...)
engine = PromptEngine(model)
```

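The snippet in this diff only constructs the engine. Based on the tutorial's description of the constructor (target model plus optional shrinking policy, maximum prompt length, prompt type, and summarization model), usage might look like the following sketch; the `join` call and all keyword names are assumptions, not taken from this diff:

```python
# Hypothetical sketch of PromptEngine usage.
model = OpenAIChatWrapper(config_name="my_openai_chat")  # as in the diff above
engine = PromptEngine(model)

system_prompt = "You are a helpful assistant."
user_input = "What is AgentScope?"

# Assumed API: combine prompt components into one prompt
# sized and formatted for the target model.
prompt = engine.join(system_prompt, user_input)
```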
2 changes: 1 addition & 1 deletion docs/sphinx_doc/zh_CN/source/tutorial/206-prompt.md
@@ -22,7 +22,7 @@ AgentScope allows developers to customize prompts according to their own needs, and also provides `
When creating an instance of `PromptEngine`, you can specify the target model and, optionally, the shrinking policy, the maximum length of the prompt, the prompt type, and a summarization model (which can be the same as the target model).

```python
model = OpenAIWrapper(...)
model = OpenAIChatWrapper(...)
engine = PromptEngine(model)
```

8 changes: 4 additions & 4 deletions src/agentscope/models/__init__.py
@@ -5,14 +5,14 @@

from loguru import logger

from .config import ModelConfig
from .config import _ModelConfig
from .model import ModelWrapperBase, ModelResponse
from .post_model import (
PostAPIModelWrapperBase,
PostAPIChatWrapper,
)
from .openai_model import (
OpenAIWrapper,
OpenAIWrapperBase,
OpenAIChatWrapper,
OpenAIDALLEWrapper,
OpenAIEmbeddingWrapper,
@@ -38,7 +38,7 @@
"ModelResponse",
"PostAPIModelWrapperBase",
"PostAPIChatWrapper",
"OpenAIWrapper",
"OpenAIWrapperBase",
"OpenAIChatWrapper",
"OpenAIDALLEWrapper",
"OpenAIEmbeddingWrapper",
@@ -156,7 +156,7 @@ def read_model_configs(
)
cfgs = configs

format_configs = ModelConfig.format_configs(configs=cfgs)
format_configs = _ModelConfig.format_configs(configs=cfgs)

# check if name is unique
for cfg in format_configs:
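The `read_model_configs` hunk shows user-supplied configs being normalized through `_ModelConfig.format_configs` and then checked for unique names. A sketch of a call site, under the assumption that `configs` accepts either a single dict or a sequence of dicts (which is what `format_configs` in the next file handles):

```python
from agentscope.models import read_model_configs

# format_configs accepts one dict or a sequence of dicts; each config_name
# must be unique, per the uniqueness check in the hunk above.
read_model_configs(
    configs=[
        {
            "config_name": "gpt4_chat",
            "model_type": "openai",          # OpenAIChatWrapper
            "model_name": "gpt-4",
        },
        {
            "config_name": "qwen_chat",
            "model_type": "dashscope_chat",  # DashScopeChatWrapper
            "model_name": "qwen-max",
        },
    ],
)
```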
10 changes: 5 additions & 5 deletions src/agentscope/models/config.py
@@ -5,7 +5,7 @@
from loguru import logger


class ModelConfig(dict):
class _ModelConfig(dict):
"""Base class for model config."""

__getattr__ = dict.__getitem__
@@ -46,14 +46,14 @@ def format_configs(
cls,
configs: Union[Sequence[dict], dict],
) -> Sequence:
"""Covert config dicts into a list of ModelConfig.
"""Covert config dicts into a list of _ModelConfig.

Args:
configs (Union[Sequence[dict], dict]): configs in dict format.

Returns:
Sequence[ModelConfig]: converted ModelConfig list.
Sequence[_ModelConfig]: converted ModelConfig list.
"""
if isinstance(configs, dict):
return [ModelConfig(**configs)]
return [ModelConfig(**cfg) for cfg in configs]
return [_ModelConfig(**configs)]
return [_ModelConfig(**cfg) for cfg in configs]
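Rename aside, the diff shows that `_ModelConfig` is a `dict` subclass wired with `__getattr__ = dict.__getitem__`, so config fields can be read with attribute syntax. A self-contained sketch of that pattern, independent of AgentScope:

```python
class AttrDict(dict):
    """A dict whose keys can also be read as attributes."""

    # Route attribute access through item access, as _ModelConfig does.
    __getattr__ = dict.__getitem__


cfg = AttrDict(config_name="gpt4_chat", model_type="openai")
assert cfg.model_type == cfg["model_type"] == "openai"

# Trade-off: a missing key now raises KeyError instead of AttributeError,
# which can surprise hasattr() and libraries that probe attributes.
```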
111 changes: 47 additions & 64 deletions src/agentscope/models/dashscope_model.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Model wrapper for DashScope models"""
from abc import ABC
from http import HTTPStatus
from typing import Any, Union

@@ -13,11 +14,9 @@
from .model import ModelWrapperBase, ModelResponse

from ..file_manager import file_manager
from ..utils.monitor import MonitorFactory
from ..constants import _DEFAULT_API_BUDGET


class DashScopeWrapper(ModelWrapperBase):
class DashScopeWrapperBase(ModelWrapperBase, ABC):
"""The model wrapper for DashScope API."""

def __init__(
@@ -26,7 +25,6 @@ def __init__(
model_name: str = None,
api_key: str = None,
generate_args: dict = None,
budget: float = _DEFAULT_API_BUDGET,
**kwargs: Any,
) -> None:
"""Initialize the DashScope wrapper.
@@ -41,44 +39,30 @@ def __init__(
generate_args (`dict`, default `None`):
The extra keyword arguments used in DashScope api generation,
e.g. `temperature`, `seed`.
budget (`float`, default `None`):
The total budget using this model. Set to `None` means no
limit.
"""
if model_name is None:
model_name = config_name
logger.warning("model_name is not set, use config_name instead.")
super().__init__(
config_name=config_name,
model_name=model_name,
generate_args=generate_args,
budget=budget,
**kwargs,
)

super().__init__(config_name=config_name)

if dashscope is None:
raise ImportError(
"Cannot find dashscope package in current python environment.",
)

self.model = model_name
self.model_name = model_name
self.generate_args = generate_args or {}

self.api_key = api_key
dashscope.api_key = self.api_key
self.max_length = None

# Set monitor accordingly
self.monitor = None
self._register_default_metrics()

def _register_default_metrics(self) -> None:
"""Register metrics to the monitor."""
raise NotImplementedError(
"The _register_default_metrics function is not Implemented.",
)


class DashScopeChatWrapper(DashScopeWrapper):
class DashScopeChatWrapper(DashScopeWrapperBase):
"""The model wrapper for DashScope's chat API."""

model_type: str = "dashscope_chat"
@@ -88,17 +72,20 @@ class DashScopeChatWrapper(DashScopeWrapper):
def _register_default_metrics(self) -> None:
# Set monitor accordingly
# TODO: set quota to the following metrics
self.monitor = MonitorFactory.get_monitor()
self.monitor.register(
self._metric("prompt_tokens", self.model),
self._metric("call_counter"),
metric_unit="times",
)
self.monitor.register(
self._metric("prompt_tokens"),
metric_unit="token",
)
self.monitor.register(
self._metric("completion_tokens", self.model),
self._metric("completion_tokens"),
metric_unit="token",
)
self.monitor.register(
self._metric("total_tokens", self.model),
self._metric("total_tokens"),
metric_unit="token",
)

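Note the pattern in `_register_default_metrics`: the old code passed `self.model` as a second argument to `self._metric(...)`, while the new code drops it and adds a `call_counter` metric, so all wrappers register the same metric names ("The metric names are unified for comparison", per a comment later in this diff). A purely illustrative sketch of the naming difference; the real `_metric` helper may work differently:

```python
# Hypothetical illustration only; not the actual _metric implementation.
def metric_old(name: str, model: str) -> str:
    return f"{model}.{name}"  # per-model names, e.g. "qwen-max.prompt_tokens"

def metric_new(name: str) -> str:
    return name               # one shared name, e.g. "prompt_tokens"
```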
@@ -169,7 +156,7 @@ def __call__(
messages = self._preprocess_role(messages)
# step3: forward to generate response
response = dashscope.Generation.call(
model=self.model,
model=self.model_name,
messages=messages,
result_format="message", # set the result to be "message" format.
**kwargs,
@@ -188,25 +175,21 @@
# step4: record the api invocation if needed
self._save_model_invocation(
arguments={
"model": self.model,
"model": self.model_name,
"messages": messages,
**kwargs,
},
response=response,
)

# step5: update monitor accordingly
try:
self.monitor.update(
{
"prompt_tokens": response.usage["input_tokens"],
"completion_tokens": response.usage["output_tokens"],
"total_tokens": response.usage["total_tokens"],
},
prefix=self.model,
)
except Exception as e:
logger.error(e)
# The metric names are unified for comparison
self.update_monitor(
call_counter=1,
prompt_tokens=response.usage["input_tokens"],
completion_tokens=response.usage["output_tokens"],
total_tokens=response.usage["total_tokens"],
)

# step6: return response
return ModelResponse(
@@ -243,17 +226,20 @@ def _preprocess_role(self, messages: list) -> list:
return messages


class DashScopeImageSynthesisWrapper(DashScopeWrapper):
class DashScopeImageSynthesisWrapper(DashScopeWrapperBase):
"""The model wrapper for DashScope Image Synthesis API."""

model_type: str = "dashscope_image_synthesis"

def _register_default_metrics(self) -> None:
# Set monitor accordingly
# TODO: set quota to the following metrics
self.monitor = MonitorFactory.get_monitor()
self.monitor.register(
self._metric("image_count", self.model),
self._metric("call_counter"),
metric_unit="times",
)
self.monitor.register(
self._metric("image_count"),
metric_unit="image",
)

@@ -300,7 +286,7 @@ def __call__(

# step2: forward to generate response
response = dashscope.ImageSynthesis.call(
model=self.model,
model=self.model_name,
prompt=prompt,
n=1,
**kwargs,
@@ -317,21 +303,18 @@
# step3: record the model api invocation if needed
self._save_model_invocation(
arguments={
"model": self.model,
"model": self.model_name,
"prompt": prompt,
**kwargs,
},
response=response,
)

# step4: update monitor accordingly
try:
self.monitor.update(
response.usage,
prefix=self.model,
)
except Exception as e:
logger.error(e)
self.update_monitor(
call_counter=1,
**response.usage,
)

# step5: return response
images = response.output["results"]
@@ -344,17 +327,20 @@ def __call__(
return ModelResponse(image_urls=urls, raw=response)


class DashScopeTextEmbeddingWrapper(DashScopeWrapper):
class DashScopeTextEmbeddingWrapper(DashScopeWrapperBase):
"""The model wrapper for DashScope Text Embedding API."""

model_type: str = "dashscope_text_embedding"

def _register_default_metrics(self) -> None:
# Set monitor accordingly
# TODO: set quota to the following metrics
self.monitor = MonitorFactory.get_monitor()
self.monitor.register(
self._metric("total_tokens", self.model),
self._metric("call_counter"),
metric_unit="times",
)
self.monitor.register(
self._metric("total_tokens"),
metric_unit="token",
)

@@ -398,7 +384,7 @@ def __call__(
# step2: forward to generate response
response = dashscope.TextEmbedding.call(
input=texts,
model=self.model,
model=self.model_name,
**kwargs,
)

@@ -414,21 +400,18 @@
# step3: record the model api invocation if needed
self._save_model_invocation(
arguments={
"model": self.model,
"model": self.model_name,
"input": texts,
**kwargs,
},
response=response,
)

# step4: update monitor accordingly
try:
self.monitor.update(
response.usage,
prefix=self.model,
)
except Exception as e:
logger.error(e)
self.update_monitor(
call_counter=1,
**response.usage,
)

# step5: return response
if len(response.output["embeddings"]) == 0:
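Across all three DashScope wrappers, the inline `try: self.monitor.update(...)` / `except: logger.error(e)` blocks are replaced by a single `self.update_monitor(...)` call, presumably centralizing the same defensive pattern in the base class. A minimal sketch under that assumption; the real `ModelWrapperBase.update_monitor` may differ:

```python
from loguru import logger


class ModelWrapperBase:  # sketch only; the real base class has far more
    monitor = None

    def update_monitor(self, **kwargs: int) -> None:
        # Hypothetical: mirror the removed inline pattern by updating the
        # monitor with the given usage numbers and logging (not raising)
        # any failure, e.g. when no monitor is registered.
        try:
            self.monitor.update(kwargs)
        except Exception as e:
            logger.error(e)
```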