diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
index 2a3df1835..0ad048c24 100644
--- a/pilot/model/parameter.py
+++ b/pilot/model/parameter.py
@@ -282,9 +282,30 @@ class ProxyModelParameters(BaseModelParameters):
             "help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
         },
     )
+
     proxy_api_key: str = field(
         metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
     )
+
+    proxy_api_base: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The base api address, such as: https://api.openai.com/v1. If None, it is derived from proxy_server_url"
+        },
+    )
+
+    proxy_api_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api type of the current proxy model. If you use Azure, set it to: azure"
+        },
+    )
+
+    proxy_api_version: Optional[str] = field(
+        default=None,
+        metadata={"help": "The api version of the current proxy model"},
+    )
+
     http_proxy: Optional[str] = field(
         default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
         metadata={"help": "The http or https proxy to use openai"},
diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py
index 01261be8f..a2aff0b86 100644
--- a/pilot/model/proxy/llms/chatgpt.py
+++ b/pilot/model/proxy/llms/chatgpt.py
@@ -1,31 +1,63 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import json
 import os
 from typing import List
+import logging
 
 import openai
 
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.model.parameter import ProxyModelParameters
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
+logger = logging.getLogger(__name__)
 
-def chatgpt_generate_stream(
-    model: ProxyModel, tokenizer, params, device, context_len=2048
-):
+
+def _initialize_openai(params: ProxyModelParameters):
+    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
+
+    api_base = params.proxy_api_base or os.getenv(
+        "OPENAI_API_BASE",
+        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
+    )
+    api_key = params.proxy_api_key or os.getenv(
+        "OPENAI_API_KEY",
+        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
+    )
+    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
+
+    if not api_base and params.proxy_server_url:
+        # Adapt the previous proxy_server_url configuration: strip the /chat/completions suffix
+        api_base = params.proxy_server_url.split("/chat/completions")[0]
+    if api_type:
+        openai.api_type = api_type
+    if api_base:
+        openai.api_base = api_base
+    if api_key:
+        openai.api_key = api_key
+    if api_version:
+        openai.api_version = api_version
+    if params.http_proxy:
+        openai.proxy = params.http_proxy
+
+    openai_params = {
+        "api_type": api_type,
+        "api_base": api_base,
+        "api_version": api_version,
+        "proxy": params.http_proxy,
+    }
+
+    return openai_params
+
+
+def _build_request(model: ProxyModel, params):
     history = []
 
     model_params = model.get_params()
-    print(f"Model: {model}, model_params: {model_params}")
+    logger.info(f"Model: {model}, model_params: {model_params}")
 
-    proxy_api_key = model_params.proxy_api_key
-    if model_params.http_proxy:
-        openai.proxy = model_params.http_proxy
-    openai.api_key = os.getenv("OPENAI_API_KEY") or proxy_api_key
-    proxyllm_backend = model_params.proxyllm_backend
-    if not proxyllm_backend:
-        proxyllm_backend = "gpt-3.5-turbo"
+    openai_params = _initialize_openai(model_params)
 
     messages: List[ModelMessage] = params["messages"]
     # Add history conversation
@@ -51,14 +83,32 @@ def chatgpt_generate_stream(
     history.append(last_user_input)
 
     payloads = {
-        "model": proxyllm_backend,  # just for test, remove this later
         "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
        "stream": True,
     }
-    res = openai.ChatCompletion.create(messages=history, **payloads)
+    proxyllm_backend = model_params.proxyllm_backend
+
+    if openai_params["api_type"] == "azure":
+        # Azure expects the deployment name via engine="<deployment_name>" instead of model
+        proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
+        payloads["engine"] = proxyllm_backend
+    else:
+        proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
+        payloads["model"] = proxyllm_backend
 
-    print(f"Send request to real model {proxyllm_backend}")
+    logger.info(
+        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
+    )
+    return history, payloads
+
+
+def chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = openai.ChatCompletion.create(messages=history, **payloads)
 
     text = ""
     for r in res:
@@ -66,3 +116,18 @@
             content = r["choices"][0]["delta"]["content"]
             text += content
             yield text
+
+
+async def async_chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = await openai.ChatCompletion.acreate(messages=history, **payloads)
+
+    text = ""
+    async for r in res:
+        if r["choices"][0]["delta"].get("content") is not None:
+            content = r["choices"][0]["delta"]["content"]
+            text += content
+            yield text