
Implement model-oriented format function in OpenAI and Post API chat wrapper #381

Merged — 5 commits merged into modelscope:main on Aug 5, 2024

Conversation


@DavdGao DavdGao commented Aug 5, 2024

Background

Considering that PostAPIChatWrapper will be used for many different models (e.g. gpt-4, gemini, glm-4, and so on), and that the OpenAI API can be used for different model services, we should choose prompt strategies in both OpenAIChatWrapper and PostAPIChatWrapper according to the model name.
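The model-name dispatch described above can be sketched as follows. This is a minimal illustration of the idea, not the PR's actual code: the function name `format_by_model_name` and the exact model-name checks are assumptions.

```python
# Hypothetical sketch: choose a prompt strategy from the model name.
# Names and branching rules here are illustrative, not AgentScope's API.
from typing import List


def format_by_model_name(model_name: str, messages: List[dict]) -> List[dict]:
    """Pick a prompt-formatting strategy based on the target model."""
    name = model_name.lower()
    if name.startswith(("gpt-", "o1")):
        # OpenAI-style models accept multi-role chat messages directly.
        return messages
    # Other chat services often expect a single merged user prompt.
    merged = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
    return [{"role": "user", "content": merged}]


prompt = format_by_model_name("glm-4", [
    {"role": "system", "content": "You're a helpful assistant"},
    {"role": "user", "content": "Hi!"},
])
```

For an OpenAI-style model name (e.g. `"gpt-4"`) the messages pass through unchanged; for others they are merged into one user message.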

Description

  • Modified the format functions in previous model wrappers into static methods so that we can call them without initializing an object.
  • Added a static function named format_for_common_chat_models in the ModelWrapperBase class.
  • Implemented the model-oriented format function in PostAPIChatWrapper and OpenAIChatWrapper.
  • Fixed a bug in the previous format function, where the conversation history may be empty, as follows:
formatted_prompt = [
    {
        "role": "system",
        "content": "You're a helpful assistant",
    },
    {
        "role": "user",
        "content": "## Conversation History",
    }, 
]

In this PR, we modified it to:

# With other messages
formatted_prompt = [
    {
        "role": "user",
        "content": (
            "You're a helpful assistant\n"
            "\n"
            "## Conversation History\n"
            "Alice: Hi!"
        ),
    },
]

# with only system messages at the beginning
formatted_prompt = [
    {
        "role": "user",
        "content": "You're a helpful assistant\n"
    },
]
  • Modified the unit tests accordingly.
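The merging behavior shown in the examples above can be approximated with a short sketch. This is not AgentScope's actual format_for_common_chat_models implementation, just a minimal reconstruction of the two cases: a leading system message becomes a prefix, and the "## Conversation History" header is only emitted when the history is non-empty.

```python
from typing import List


def format_common_chat(msgs: List[dict]) -> List[dict]:
    """Merge a leading system message and named dialogue turns into a
    single user message. Sketch only; mirrors the examples above, not
    AgentScope's real format_for_common_chat_models."""
    parts = []
    start = 0
    if msgs and msgs[0]["role"] == "system":
        parts.append(msgs[0]["content"] + "\n")
        start = 1
    history = [f"{m['name']}: {m['content']}" for m in msgs[start:]]
    if history:  # skip the header entirely when the history is empty
        parts.append("\n## Conversation History\n" + "\n".join(history))
    return [{"role": "user", "content": "".join(parts)}]
```

With only a system message this yields a single user message containing just the system text, matching the second example in the PR description.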

Checklist

Please check the following items before code is ready to be reviewed.

  • Code has passed all tests
  • Docstrings have been added/updated in Google Style
  • Documentation has been updated
  • Code is ready for review

…ugs in previous format function, where the conversation maybe empty; Modify unit test accordingly.
@DavdGao DavdGao added enhancement New feature or request ready for review labels Aug 5, 2024

@garyzhang99 garyzhang99 left a comment


Please see the inline comments.

src/agentscope/models/gemini_model.py (review thread resolved, outdated)
src/agentscope/models/openai_model.py (review thread resolved)
@DavdGao DavdGao mentioned this pull request Aug 5, 2024
@DavdGao DavdGao linked an issue Aug 5, 2024 that may be closed by this pull request
src/agentscope/models/openai_model.py (review thread resolved, outdated)
src/agentscope/constants.py (review thread resolved)
# Conflicts:
#	src/agentscope/models/gemini_model.py
#	src/agentscope/models/model.py
@pan-x-c pan-x-c left a comment


LGTM


@pan-x-c pan-x-c merged commit 117d8e4 into modelscope:main Aug 5, 2024
14 checks passed

yawzhe commented Aug 29, 2024

In my self-defined model wrapper, I found that it only implements parameter passing; there is no streaming inside, and code I write myself simply has no effect. I check the request mode inside it and add stream myself, but it does not take effect at all. Moreover, even if I delete all the code after request_kwargs, nothing changes; even if there is an error, it makes no difference. How exactly is this custom model wrapper supposed to be defined? I can only get non-streaming output. For streaming, I followed the OpenAI wrapper, changed the stream check and used an iterator, but it simply does not work and raises no error; the output is always non-streaming.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @time : 2024/8/28 9:54
# @file : local_model.py

from abc import ABC
from typing import Any, Union, Sequence, List
from loguru import logger

from agentscope.models import ModelWrapperBase
from agentscope.models import ModelResponse
import requests
from agentscope.constants import _DEFAULT_MESSAGES_KEY, _DEFAULT_MAX_RETRIES, _DEFAULT_RETRY_INTERVAL
from agentscope.message import Msg
import time
import json
class GewuModelWrapper(ModelWrapperBase):

    model_type = 'ge_api_chat'

    def __init__(
            self,
            config_name="",
            model_name="",
            api_url="",
            headers: dict = None,
            max_length: int = 2048,
            timeout: int = 30,
            # json_args: dict = None,
            # post_args: dict = None,
            max_retries: int = _DEFAULT_MAX_RETRIES,
            messages_key: str = _DEFAULT_MESSAGES_KEY,
            retry_interval: int = _DEFAULT_RETRY_INTERVAL,
            **kwargs: Any,
    ) -> None:
        super().__init__(config_name=config_name, model_name=model_name)

        self.api_url = api_url
        self.headers = headers
        self.max_length = max_length
        self.timeout = timeout
        self.json_args = {}
        self.post_args = {}
        self.max_retries = max_retries
        self.messages_key = messages_key
        self.retry_interval = retry_interval

    def _parse_response(self, response: dict) -> ModelResponse:
        """Parse the response json data into ModelResponse"""
        return ModelResponse(raw=response)

    def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
        """Calling the model with requests.post.

        Args:
            input_ (`str`):
                The input string to the model.

        Returns:
            `dict`: A dictionary that contains the response of the model
            and related information (e.g. cost, time, the number of
            tokens, etc.).

        Note:
            `parse_func`, `fault_handler` and `max_retries` are reserved
            for `_response_parse_decorator` to parse and check the
            response generated by the model wrapper. Their usages are
            listed as follows:
                - `parse_func` is a callable function used to parse and
                check the response generated by the model, which takes
                the response as input.
                - `max_retries` is the maximum number of retries when
                `parse_func` raises an exception.
                - `fault_handler` is a callable function which is called
                when the response generated by the model is invalid after
                `max_retries` retries.
        """
        # step1: prepare keyword arguments
        print(kwargs)  # was `print(**kwargs)`, which raises a TypeError
        post_args = {**self.post_args, **kwargs}
        self.json_args['model_name'] = 'gewu_14b_v1'
        self.json_args['model'] = 'gewu_14b_v1'

        request_kwargs = {
            "url": self.api_url,
            "json": {self.messages_key: input_, **self.json_args},
            "headers": self.headers or {},
            **post_args,
        }

        # step2: prepare post requests
        for i in range(1, self.max_retries + 1):
            response = requests.post(**request_kwargs)
            if response.status_code == requests.codes.ok:
                break
            if i < self.max_retries:
                logger.warning(
                    f"Failed to call the model with "
                    f"requests.codes == {response.status_code}, retry "
                    f"{i + 1}/{self.max_retries} times",
                )
                time.sleep(i * self.retry_interval)

        # step3: record model invocation
        # record the model api invocation, which will be skipped if
        # `FileManager.save_api_invocation` is `False`
        self._save_model_invocation(
            arguments=request_kwargs,
            response=response.json(),
        )

        # step4: parse the response
        if response.status_code == requests.codes.ok:
            return self._parse_response(response.json())
        else:
            logger.error(json.dumps(request_kwargs, indent=4))
            raise RuntimeError(
                f"Failed to call the model with {response.json()}",
            )


class GewuAPIChatWrapper(GewuModelWrapper):
    """A post api model wrapper compatible with openai chat, e.g., vLLM,
    FastChat."""

    model_type: str = "ge_api_chat"

    def _parse_response(self, response: dict) -> ModelResponse:
        return ModelResponse(
            text=response["choices"][0]["message"]["content"],
        )

    def format(
            self,
            *args: Union[Msg, Sequence[Msg]],
    ) -> Union[List[dict]]:
        """Format the input messages into a list of dict, which is
        compatible to OpenAI Chat API.

        Args:
            args (`Union[Msg, Sequence[Msg]]`):
                The input arguments to be formatted, where each argument
                should be a `Msg` object, or a list of `Msg` objects.
                In distribution, placeholder is also allowed.

        Returns:
            `Union[List[dict]]`:
                The formatted messages.
        """
        # Format according to the potential model field in the json_args
        return ModelWrapperBase.format_for_common_chat_models(*args)


if __name__ == '__main__':
    pass
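Regarding the streaming question above: the wrapper as posted always reads the full response body via response.json(), so adding a stream flag to json_args alone can have no effect. A streaming call needs requests.post(..., stream=True) plus iteration over the server-sent-event lines. A hedged sketch follows, assuming an OpenAI-compatible streaming endpoint (e.g. vLLM); the function names and payload fields here are assumptions, not AgentScope API.

```python
import json
import requests


def parse_sse_line(raw: str):
    """Extract the content delta from one SSE data line, or None."""
    if not raw or not raw.startswith("data:"):
        return None
    data = raw[len("data:"):].strip()
    if data == "[DONE]":
        return None
    chunk = json.loads(data)
    return chunk["choices"][0]["delta"].get("content")


def stream_chat(api_url: str, payload: dict, headers: dict = None):
    """Yield content deltas from an OpenAI-style streaming chat endpoint.
    Sketch only; assumes the server emits OpenAI-compatible SSE chunks."""
    payload = {**payload, "stream": True}
    with requests.post(api_url, json=payload, headers=headers or {},
                       stream=True, timeout=30) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines(decode_unicode=True):
            delta = parse_sse_line(raw)
            if delta:
                yield delta
```

A `__call__` that wants streaming would return (or iterate) this generator instead of calling `response.json()` once; without `stream=True` on the `requests.post` call itself, the whole body is buffered and the output is necessarily non-streaming.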


Successfully merging this pull request may close these issues.

The POSTAPI using help
4 participants