From 3b893bc5916fc731a0d7bb2737ed2d88c4fc7dae Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 13 Jan 2024 17:56:06 +0800 Subject: [PATCH 1/5] refactor: Refactor proxy LLM --- dbgpt/app/chat_adapter.py | 19 +- dbgpt/app/scene/base_chat.py | 211 +++------ dbgpt/app/scene/operator/app_operator.py | 129 +++++- dbgpt/core/awel/dag/base.py | 6 +- dbgpt/core/awel/operator/base.py | 10 +- dbgpt/core/awel/operator/stream_operator.py | 6 + dbgpt/core/awel/runner/local_runner.py | 17 +- dbgpt/core/interface/llm.py | 40 +- dbgpt/core/interface/message.py | 37 +- dbgpt/core/interface/tests/test_message.py | 8 +- dbgpt/model/__init__.py | 4 +- dbgpt/model/adapter/base.py | 29 +- dbgpt/model/adapter/model_adapter.py | 3 +- dbgpt/model/adapter/proxy_adapter.py | 238 ++++++++++ dbgpt/model/cluster/client.py | 19 +- dbgpt/model/cluster/worker/default_worker.py | 11 +- dbgpt/model/llm_out/proxy_llm.py | 50 +- dbgpt/model/loader.py | 3 +- dbgpt/model/operator/llm_operator.py | 2 +- dbgpt/model/parameter.py | 10 + dbgpt/model/proxy/__init__.py | 15 + dbgpt/model/proxy/base.py | 242 ++++++++++ dbgpt/model/proxy/llms/baichuan.py | 1 + dbgpt/model/proxy/llms/bard.py | 1 + dbgpt/model/proxy/llms/chatgpt.py | 437 +++++++++--------- dbgpt/model/proxy/llms/gemini.py | 192 +++++--- dbgpt/model/proxy/llms/proxy_model.py | 8 +- dbgpt/model/proxy/llms/spark.py | 184 +++++--- dbgpt/model/proxy/llms/tongyi.py | 156 ++++--- dbgpt/model/proxy/llms/wenxin.py | 236 ++++++---- dbgpt/model/proxy/llms/zhipu.py | 150 +++--- dbgpt/model/utils/chatgpt_utils.py | 163 +------ .../cache/operator.py} | 111 +---- examples/awel/simple_chat_dag_example.py | 32 +- examples/awel/simple_chat_history_example.py | 2 +- examples/awel/simple_dag_example.py | 3 +- 36 files changed, 1737 insertions(+), 1048 deletions(-) create mode 100644 dbgpt/model/adapter/proxy_adapter.py create mode 100644 dbgpt/model/proxy/base.py rename dbgpt/{model/operator/model_operator.py => storage/cache/operator.py} (68%) diff --git a/dbgpt/app/chat_adapter.py b/dbgpt/app/chat_adapter.py index d0c5e84dc..d4eb3f3a7 100644 --- a/dbgpt/app/chat_adapter.py +++ b/dbgpt/app/chat_adapter.py @@ -178,14 +178,15 @@ def get_generate_stream_func(self, model_path: str): return falcon_generate_output -class ProxyllmChatAdapter(BaseChatAdpter): - def match(self, model_path: str): - return "proxyllm" in model_path - - def get_generate_stream_func(self, model_path: str): - from dbgpt.model.llm_out.proxy_llm import proxyllm_generate_stream - - return proxyllm_generate_stream +# +# class ProxyllmChatAdapter(BaseChatAdpter): +# def match(self, model_path: str): +# return "proxyllm" in model_path +# +# def get_generate_stream_func(self, model_path: str): +# from dbgpt.model.llm_out.proxy_llm import proxyllm_generate_stream +# +# return proxyllm_generate_stream class GorillaChatAdapter(BaseChatAdpter): @@ -286,6 +287,6 @@ def get_conv_template(self, model_path: str) -> Conversation: register_llm_model_chat_adapter(InternLMChatAdapter) # Proxy model for test and develop, it's cheap for us now. 
-register_llm_model_chat_adapter(ProxyllmChatAdapter) +# register_llm_model_chat_adapter(ProxyllmChatAdapter) register_llm_model_chat_adapter(BaseChatAdpter) diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index b9a22f7e8..e2647fb4c 100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -3,7 +3,7 @@ import logging import traceback from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any, AsyncIterator, Dict from dbgpt._private.config import Config from dbgpt._private.pydantic import Extra @@ -11,12 +11,13 @@ from dbgpt.app.scene.operator.app_operator import ( AppChatComposerOperator, ChatComposerInput, + build_cached_chat_operator, ) from dbgpt.component import ComponentType -from dbgpt.core.awel import DAG, BaseOperator, InputOperator, SimpleCallDataInputSource +from dbgpt.core import LLMClient, ModelOutput, ModelRequest, ModelRequestContext from dbgpt.core.interface.message import StorageConversation +from dbgpt.model import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory -from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator from dbgpt.serve.conversation.serve import Serve as ConversationServe from dbgpt.util import get_or_create_event_loop from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async @@ -64,6 +65,10 @@ class BaseChat(ABC): keep_start_rounds: int = 0 keep_end_rounds: int = 0 + # Some model not support system role, this config is used to control whether to + # convert system message to human message + auto_convert_message: bool = True + class Config: """Configuration for this pydantic object.""" @@ -124,14 +129,15 @@ def __init__(self, chat_param: Dict): ComponentType.EXECUTOR_DEFAULT, ExecutorFactory ).create() - self._model_operator: BaseOperator = _build_model_operator() - self._model_stream_operator: BaseOperator = _build_model_operator( - is_stream=True, dag_name="llm_stream_model_dag" - ) + # self._model_operator: BaseOperator = _build_model_operator() + # self._model_stream_operator: BaseOperator = _build_model_operator( + # is_stream=True, dag_name="llm_stream_model_dag" + # ) # In v1, we will transform the message to compatible format of specific model # In the future, we will upgrade the message version to v2, and the message will be compatible with all models self._message_version = chat_param.get("message_version", "v2") + self._chat_param = chat_param class Config: """Configuration for this pydantic object.""" @@ -153,6 +159,27 @@ async def generate_input_values(self) -> Dict: a dictionary to be formatted by prompt template """ + @property + def llm_client(self) -> LLMClient: + """Return the LLM client.""" + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + return DefaultLLMClient( + worker_manager, auto_convert_message=self.auto_convert_message + ) + + async def call_llm_operator(self, request: ModelRequest) -> ModelOutput: + llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP) + return await llm_task.call(call_data={"data": request}) + + async def call_streaming_operator( + self, request: ModelRequest + ) -> AsyncIterator[ModelOutput]: + llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP) + async for out in await llm_task.call_stream(call_data={"data": request}): + yield out + def do_action(self, prompt_response): return prompt_response @@ -185,7 +212,7 @@ def get_llm_speak(self, 
prompt_define_response): speak_to_user = prompt_define_response return speak_to_user - async def __call_base(self): + async def _build_model_request(self) -> ModelRequest: input_values = await self.generate_input_values() # Load history self.history_messages = self.current_message.get_history_message() @@ -195,21 +222,6 @@ async def __call_base(self): "%Y-%m-%d %H:%M:%S" ) self.current_message.tokens = 0 - # TODO: support handle history message by tokens - # if self.prompt_template.template: - # metadata = { - # "template_scene": self.prompt_template.template_scene, - # "input_values": input_values, - # } - # with root_tracer.start_span( - # "BaseChat.__call_base.prompt_template.format", metadata=metadata - # ): - # current_prompt = self.prompt_template.format(**input_values) - # ### prompt context token adapt according to llm max context length - # current_prompt = await self.prompt_context_token_adapt( - # prompt=current_prompt - # ) - # self.current_message.add_system_message(current_prompt) keep_start_rounds = ( self.keep_start_rounds @@ -219,11 +231,25 @@ async def __call_base(self): keep_end_rounds = ( self.keep_end_rounds if self.prompt_template.need_historical_messages else 0 ) + req_ctx = ModelRequestContext( + stream=self.prompt_template.stream_out, + user_name=self._chat_param.get("user_name"), + sys_code=self._chat_param.get("sys_code"), + chat_mode=self.chat_mode.value(), + span_id=root_tracer.get_current_span_id(), + ) node = AppChatComposerOperator( + model=self.llm_model, + temperature=float(self.prompt_template.temperature), + max_new_tokens=int(self.prompt_template.max_new_tokens), prompt=self.prompt_template.prompt, + message_version=self._message_version, + echo=self.llm_echo, + streaming=self.prompt_template.stream_out, keep_start_rounds=keep_start_rounds, keep_end_rounds=keep_end_rounds, str_history=self.prompt_template.str_history, + request_context=req_ctx, ) node_input = { "data": ChatComposerInput( @@ -231,22 +257,9 @@ async def __call_base(self): ) } # llm_messages = self.generate_llm_messages() - llm_messages = await node.call(call_data=node_input) - if not CFG.NEW_SERVER_MODE: - # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict - # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post - llm_messages = list(map(lambda m: m.dict(), llm_messages)) - - payload = { - "model": self.llm_model, - "prompt": "", - "messages": llm_messages, - "temperature": float(self.prompt_template.temperature), - "max_new_tokens": int(self.prompt_template.max_new_tokens), - "echo": self.llm_echo, - "version": self._message_version, - } - return payload + model_request: ModelRequest = await node.call(call_data=node_input) + model_request.context.cache_enable = self.model_cache_enable + return model_request def stream_plugin_call(self, text): return text @@ -271,23 +284,19 @@ def _get_span_metadata(self, payload: Dict) -> Dict: async def stream_call(self): # TODO Retry when server connection error - payload = await self.__call_base() + payload = await self._build_model_request() - self.skip_echo_len = len(payload.get("prompt").replace("", " ")) + 11 logger.info(f"payload request: \n{payload}") ai_response_text = "" span = root_tracer.start_span( - "BaseChat.stream_call", metadata=self._get_span_metadata(payload) + "BaseChat.stream_call", metadata=payload.to_dict() ) - payload["span_id"] = span.span_id - payload["model_cache_enable"] = self.model_cache_enable + payload.span_id = span.span_id 
try: - async for output in await self._model_stream_operator.call_stream( - call_data={"data": payload} - ): + async for output in self.call_streaming_operator(payload): # Plugin research in result generation msg = self.prompt_template.output_parser.parse_model_stream_resp_ex( - output, self.skip_echo_len + output, 0 ) view_msg = self.stream_plugin_call(msg) view_msg = view_msg.replace("\n", "\\n") @@ -308,19 +317,16 @@ async def stream_call(self): self.current_message.end_current_round() async def nostream_call(self): - payload = await self.__call_base() - logger.info(f"Request: \n{payload}") - ai_response_text = "" + payload = await self._build_model_request() span = root_tracer.start_span( - "BaseChat.nostream_call", metadata=self._get_span_metadata(payload) + "BaseChat.nostream_call", metadata=payload.to_dict() ) - payload["span_id"] = span.span_id - payload["model_cache_enable"] = self.model_cache_enable + logger.info(f"Request: \n{payload}") + ai_response_text = "" + payload.span_id = span.span_id try: with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"): - model_output = await self._model_operator.call( - call_data={"data": payload} - ) + model_output = await self.call_llm_operator(payload) ### output parse ai_response_text = ( @@ -380,13 +386,12 @@ async def nostream_call(self): return self.current_ai_response() async def get_llm_response(self): - payload = await self.__call_base() + payload = await self._build_model_request() logger.info(f"Request: \n{payload}") ai_response_text = "" - payload["model_cache_enable"] = self.model_cache_enable prompt_define_response = None try: - model_output = await self._model_operator.call(call_data={"data": payload}) + model_output = await self.call_llm_operator(payload) ### output parse ai_response_text = ( self.prompt_template.output_parser.parse_model_nostream_resp( @@ -544,89 +549,3 @@ def _generate_numbered_list(self) -> str: for dict_item in antv_charts for key, value in dict_item.items() ) - - -def _build_model_operator( - is_stream: bool = False, dag_name: str = "llm_model_dag" -) -> BaseOperator: - """Builds and returns a model processing workflow (DAG) operator. - - This function constructs a Directed Acyclic Graph (DAG) for processing data using a model. - It includes caching and branching logic to either fetch results from a cache or process - data using the model. It supports both streaming and non-streaming modes. - - .. code-block:: python - - input_node >> cache_check_branch_node - cache_check_branch_node >> model_node >> save_cached_node >> join_node - cache_check_branch_node >> cached_node >> join_node - - equivalent to:: - - -> model_node -> save_cached_node -> - / \ - input_node -> cache_check_branch_node ---> join_node - \ / - -> cached_node ------------------- -> - - Args: - is_stream (bool): Flag to determine if the operator should process data in streaming mode. - dag_name (str): Name of the DAG. - - Returns: - BaseOperator: The final operator in the constructed DAG, typically a join node. 
- """ - from dbgpt.core.awel import JoinOperator - from dbgpt.model.cluster import WorkerManagerFactory - from dbgpt.model.operator.model_operator import ( - CachedModelOperator, - CachedModelStreamOperator, - ModelCacheBranchOperator, - ModelSaveCacheOperator, - ModelStreamSaveCacheOperator, - ) - from dbgpt.storage.cache import CacheManager - - # Fetch worker and cache managers from the system configuration - worker_manager = CFG.SYSTEM_APP.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - cache_manager: CacheManager = CFG.SYSTEM_APP.get_component( - ComponentType.MODEL_CACHE_MANAGER, CacheManager - ) - # Define task names for the model and cache nodes - model_task_name = "llm_model_node" - cache_task_name = "llm_model_cache_node" - - with DAG(dag_name): - # Create an input node - input_node = InputOperator(SimpleCallDataInputSource()) - # Determine if the workflow should operate in streaming mode - if is_stream: - model_node = ModelStreamOperator(worker_manager, task_name=model_task_name) - cached_node = CachedModelStreamOperator( - cache_manager, task_name=cache_task_name - ) - save_cached_node = ModelStreamSaveCacheOperator(cache_manager) - else: - model_node = ModelOperator(worker_manager, task_name=model_task_name) - cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name) - save_cached_node = ModelSaveCacheOperator(cache_manager) - - # Create a branch node to decide between fetching from cache or processing with the model - cache_check_branch_node = ModelCacheBranchOperator( - cache_manager, - model_task_name="llm_model_node", - cache_task_name="llm_model_cache_node", - ) - # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output - join_node = JoinOperator( - combine_function=lambda model_out, cache_out: cache_out or model_out - ) - - # Define the workflow structure using the >> operator - input_node >> cache_check_branch_node - cache_check_branch_node >> model_node >> save_cached_node >> join_node - cache_check_branch_node >> cached_node >> join_node - - return join_node diff --git a/dbgpt/app/scene/operator/app_operator.py b/dbgpt/app/scene/operator/app_operator.py index 799b89505..ef5062b9d 100644 --- a/dbgpt/app/scene/operator/app_operator.py +++ b/dbgpt/app/scene/operator/app_operator.py @@ -1,17 +1,36 @@ import dataclasses from typing import Any, Dict, List, Optional -from dbgpt.core import BaseMessage, ChatPromptTemplate, ModelMessage +from dbgpt import SystemApp +from dbgpt.component import ComponentType +from dbgpt.core import ( + BaseMessage, + ChatPromptTemplate, + LLMClient, + ModelRequest, + ModelRequestContext, +) from dbgpt.core.awel import ( DAG, BaseOperator, InputOperator, + JoinOperator, MapOperator, SimpleCallDataInputSource, ) from dbgpt.core.operator import ( BufferedConversationMapperOperator, HistoryPromptBuilderOperator, + LLMBranchOperator, +) +from dbgpt.model.operator import LLMOperator, StreamingLLMOperator +from dbgpt.storage.cache.operator import ( + CachedModelOperator, + CachedModelStreamOperator, + CacheManager, + ModelCacheBranchOperator, + ModelSaveCacheOperator, + ModelStreamSaveCacheOperator, ) @@ -23,7 +42,7 @@ class ChatComposerInput: prompt_dict: Dict[str, Any] -class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]): +class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]): """App chat composer operator. TODO: Support more history merge mode. 
@@ -31,29 +50,55 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]] def __init__( self, + model: str, + temperature: float, + max_new_tokens: int, prompt: ChatPromptTemplate, + message_version: str = "v2", + echo: bool = False, + streaming: bool = True, history_key: str = "chat_history", history_merge_mode: str = "window", keep_start_rounds: Optional[int] = None, keep_end_rounds: Optional[int] = None, str_history: bool = False, + request_context: ModelRequestContext = None, **kwargs, ): super().__init__(**kwargs) + if not request_context: + request_context = ModelRequestContext(stream=streaming) self._prompt_template = prompt self._history_key = history_key self._history_merge_mode = history_merge_mode self._keep_start_rounds = keep_start_rounds self._keep_end_rounds = keep_end_rounds self._str_history = str_history + self._model_name = model + self._temperature = temperature + self._max_new_tokens = max_new_tokens + self._message_version = message_version + self._echo = echo + self._streaming = streaming + self._request_context = request_context self._sub_compose_dag = self._build_composer_dag() - async def map(self, input_value: ChatComposerInput) -> List[ModelMessage]: + async def map(self, input_value: ChatComposerInput) -> ModelRequest: end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0] # Sub dag, use the same dag context in the parent dag - return await end_node.call( + messages = await end_node.call( call_data={"data": input_value}, dag_ctx=self.current_dag_context ) + span_id = self._request_context.span_id + model_request = ModelRequest.build_request( + model=self._model_name, + messages=messages, + context=self._request_context, + temperature=self._temperature, + max_new_tokens=self._max_new_tokens, + span_id=span_id, + ) + return model_request def _build_composer_dag(self) -> DAG: with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag: @@ -83,3 +128,79 @@ def _build_composer_dag(self) -> DAG: ) return composer_dag + + +def build_cached_chat_operator( + llm_client: LLMClient, + is_streaming: bool, + system_app: SystemApp, + cache_manager: Optional[CacheManager] = None, +): + """Builds and returns a model processing workflow (DAG) operator. + + This function constructs a Directed Acyclic Graph (DAG) for processing data using a model. + It includes caching and branching logic to either fetch results from a cache or process + data using the model. It supports both streaming and non-streaming modes. + + .. code-block:: python + + input_task >> cache_check_branch_task + cache_check_branch_task >> llm_task >> save_cache_task >> join_task + cache_check_branch_task >> cache_task >> join_task + + equivalent to:: + + -> llm_task -> save_cache_task -> + / \ + input_task -> cache_check_branch_task ---> join_task + \ / + -> cache_task ------------------- -> + + Args: + llm_client (LLMClient): The LLM client for processing data using the model. + is_streaming (bool): Whether the model is a streaming model. + system_app (SystemApp): The system app. + cache_manager (CacheManager, optional): The cache manager for managing cache operations. Defaults to None. + + Returns: + BaseOperator: The final operator in the constructed DAG, typically a join node. 
+ """ + # Define task names for the model and cache nodes + model_task_name = "llm_model_node" + cache_task_name = "llm_model_cache_node" + if not cache_manager: + cache_manager: CacheManager = system_app.get_component( + ComponentType.MODEL_CACHE_MANAGER, CacheManager + ) + + with DAG("dbgpt_awel_app_model_infer_with_cached") as dag: + # Create an input task + input_task = InputOperator(SimpleCallDataInputSource()) + # Create a branch task to decide between fetching from cache or processing with the model + if is_streaming: + llm_task = StreamingLLMOperator(llm_client, task_name=model_task_name) + cache_task = CachedModelStreamOperator( + cache_manager, task_name=cache_task_name + ) + save_cache_task = ModelStreamSaveCacheOperator(cache_manager) + else: + llm_task = LLMOperator(llm_client, task_name=model_task_name) + cache_task = CachedModelOperator(cache_manager, task_name=cache_task_name) + save_cache_task = ModelSaveCacheOperator(cache_manager) + + # Create a branch node to decide between fetching from cache or processing with the model + cache_check_branch_task = ModelCacheBranchOperator( + cache_manager, + model_task_name=model_task_name, + cache_task_name=cache_task_name, + ) + # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output + join_task = JoinOperator( + combine_function=lambda model_out, cache_out: cache_out or model_out + ) + + # Define the workflow structure using the >> operator + input_task >> cache_check_branch_task + cache_check_branch_task >> llm_task >> save_cache_task >> join_task + cache_check_branch_task >> cache_task >> join_task + return join_task diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index 70977d536..7f1017c7a 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -384,7 +384,7 @@ async def get_from_share_data(self, key: str) -> Any: return self._share_data.get(key) async def save_to_share_data( - self, key: str, data: Any, overwrite: Optional[str] = None + self, key: str, data: Any, overwrite: bool = False ) -> None: if key in self._share_data and not overwrite: raise ValueError(f"Share data key {key} already exists") @@ -407,7 +407,7 @@ async def get_task_share_data(self, task_name: str, key: str) -> Any: return self.get_from_share_data(_build_task_key(task_name, key)) async def save_task_share_data( - self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None + self, task_name: str, key: str, data: Any, overwrite: bool = False ) -> None: """Save share data by task name and key @@ -415,7 +415,7 @@ async def save_task_share_data( task_name (str): The task name key (str): The share data key data (Any): The share data - overwrite (Optional[str], optional): Whether overwrite the share data if the key already exists. + overwrite (bool): Whether overwrite the share data if the key already exists. Defaults to None. Raises: diff --git a/dbgpt/core/awel/operator/base.py b/dbgpt/core/awel/operator/base.py index 8fa3e6905..b56ed9b86 100644 --- a/dbgpt/core/awel/operator/base.py +++ b/dbgpt/core/awel/operator/base.py @@ -46,7 +46,7 @@ async def execute_workflow( node: "BaseOperator", call_data: Optional[CALL_DATA] = None, streaming_call: bool = False, - dag_ctx: Optional[DAGContext] = None, + exist_dag_ctx: Optional[DAGContext] = None, ) -> DAGContext: """Execute the workflow starting from a given operator. @@ -54,7 +54,7 @@ async def execute_workflow( node (RunnableDAGNode): The starting node of the workflow to be executed. 
call_data (CALL_DATA): The data pass to root operator node. streaming_call (bool): Whether the call is a streaming call. - dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. + exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. Returns: DAGContext: The context after executing the workflow, containing the final state and data. """ @@ -190,7 +190,9 @@ async def call( Returns: OUT: The output of the node after execution. """ - out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx) + out_ctx = await self._runner.execute_workflow( + self, call_data, exist_dag_ctx=dag_ctx + ) return out_ctx.current_task_context.task_output.output def _blocking_call( @@ -230,7 +232,7 @@ async def call_stream( AsyncIterator[OUT]: An asynchronous iterator over the output stream. """ out_ctx = await self._runner.execute_workflow( - self, call_data, streaming_call=True, dag_ctx=dag_ctx + self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx ) return out_ctx.current_task_context.task_output.output_stream diff --git a/dbgpt/core/awel/operator/stream_operator.py b/dbgpt/core/awel/operator/stream_operator.py index 5927a43fc..526249704 100644 --- a/dbgpt/core/awel/operator/stream_operator.py +++ b/dbgpt/core/awel/operator/stream_operator.py @@ -9,6 +9,12 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]): async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context + call_data = curr_task_ctx.call_data + if call_data: + call_data = await curr_task_ctx._call_data_to_output() + output = await call_data.streamify(self.streamify) + curr_task_ctx.set_task_output(output) + return output output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify( self.streamify ) diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index 680b6f974..6fb69619b 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -19,21 +19,22 @@ async def execute_workflow( node: BaseOperator, call_data: Optional[CALL_DATA] = None, streaming_call: bool = False, - dag_ctx: Optional[DAGContext] = None, + exist_dag_ctx: Optional[DAGContext] = None, ) -> DAGContext: # Save node output # dag = node.dag job_manager = JobManager.build_from_end_node(node, call_data) - if not dag_ctx: + if not exist_dag_ctx: # Create DAG context node_outputs: Dict[str, TaskContext] = {} - dag_ctx = DAGContext( - streaming_call=streaming_call, - node_to_outputs=node_outputs, - node_name_to_ids=job_manager._node_name_to_ids, - ) else: - node_outputs = dag_ctx._node_to_outputs + # Share node output with exist dag context + node_outputs = exist_dag_ctx._node_to_outputs + dag_ctx = DAGContext( + streaming_call=streaming_call, + node_to_outputs=node_outputs, + node_name_to_ids=job_manager._node_name_to_ids, + ) logger.info( f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}" ) diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index 7b705dd77..21050527e 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -127,6 +127,11 @@ class ModelRequestContext: extra: Optional[Dict[str, Any]] = field(default_factory=dict) """The extra information of the model inference.""" + request_id: Optional[str] = None + """The request id of the model inference.""" + cache_enable: Optional[bool] = False + """Whether to enable the cache for the 
model inference""" + @dataclass @PublicAPI(stability="beta") @@ -205,6 +210,11 @@ def to_dict(self) -> Dict[str, Any]: # Skip None fields return {k: v for k, v in asdict(new_reqeust).items() if v} + def to_trace_metadata(self): + metadata = self.to_dict() + metadata["prompt"] = self.messages_to_string() + return metadata + def get_messages(self) -> List[ModelMessage]: """Get the messages. @@ -234,10 +244,12 @@ def get_single_user_message(self) -> Optional[ModelMessage]: def build_request( model: str, messages: List[ModelMessage], - context: Union[ModelRequestContext, Dict[str, Any], BaseModel], + context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None, stream: Optional[bool] = False, **kwargs, ): + if not context: + context = ModelRequestContext(stream=stream) context_dict = None if isinstance(context, dict): context_dict = context @@ -261,14 +273,22 @@ def _build(model: str, prompt: str, **kwargs): **kwargs, ) - def to_openai_messages(self) -> List[Dict[str, Any]]: - """Convert the messages to the format of OpenAI API. + def to_common_messages( + self, support_system_role: bool = True + ) -> List[Dict[str, Any]]: + """Convert the messages to the common format(like OpenAI API). This function will move last user message to the end of the list. + Args: + support_system_role (bool): Whether to support system role + Returns: List[Dict[str, Any]]: The messages in the format of OpenAI API. + Raises: + ValueError: If the message role is not supported + Examples: .. code-block:: python @@ -298,7 +318,17 @@ def to_openai_messages(self) -> List[Dict[str, Any]]: m if isinstance(m, ModelMessage) else ModelMessage(**m) for m in self.messages ] - return ModelMessage.to_openai_messages(messages) + return ModelMessage.to_common_messages( + messages, support_system_role=support_system_role + ) + + def messages_to_string(self) -> str: + """Convert the messages to string. + + Returns: + str: The messages in string format. + """ + return ModelMessage.messages_to_string(self.get_messages()) @dataclass @@ -478,7 +508,7 @@ def convert( if not model_metadata or not model_metadata.ext_metadata: logger.warning("No model metadata, skip message system message conversion") return messages - if model_metadata.ext_metadata.support_system_message: + if not model_metadata.ext_metadata.support_system_message: # 3. Convert the messages to no system message return self.convert_to_no_system_message(messages, model_metadata) return messages diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 25ba8b665..d6b548f55 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -197,15 +197,24 @@ def from_openai_messages( return result @staticmethod - def to_openai_messages( - messages: List["ModelMessage"], convert_to_compatible_format: bool = False + def to_common_messages( + messages: List["ModelMessage"], + convert_to_compatible_format: bool = False, + support_system_role: bool = True, ) -> List[Dict[str, str]]: - """Convert to OpenAI message format and - hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating) + """Convert to common message format(e.g. 
OpenAI message format) and + huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating) Args: messages (List["ModelMessage"]): The model messages convert_to_compatible_format (bool): Whether to convert to compatible format + support_system_role (bool): Whether to support system role + + Returns: + List[Dict[str, str]]: The common messages + + Raises: + ValueError: If the message role is not supported """ history = [] # Add history conversation @@ -213,6 +222,8 @@ def to_openai_messages( if message.role == ModelMessageRoleType.HUMAN: history.append({"role": "user", "content": message.content}) elif message.role == ModelMessageRoleType.SYSTEM: + if not support_system_role: + raise ValueError("Current model not support system role") history.append({"role": "system", "content": message.content}) elif message.role == ModelMessageRoleType.AI: history.append({"role": "assistant", "content": message.content}) @@ -250,6 +261,18 @@ def get_printable_message(messages: List["ModelMessage"]) -> str: return str_msg + @staticmethod + def messages_to_string(messages: List["ModelMessage"]) -> str: + """Convert messages to str + + Args: + messages (List[ModelMessage]): The messages + + Returns: + str: The str messages + """ + return _messages_to_str(messages) + _SingleRoundMessage = List[BaseMessage] _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]] @@ -264,7 +287,7 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]: def _messages_to_str( - messages: List[BaseMessage], + messages: List[Union[BaseMessage, ModelMessage]], human_prefix: str = "Human", ai_prefix: str = "AI", system_prefix: str = "System", @@ -272,7 +295,7 @@ def _messages_to_str( """Convert messages to str Args: - messages (List[BaseMessage]): The messages + messages (List[Union[BaseMessage, ModelMessage]]): The messages human_prefix (str): The human prefix ai_prefix (str): The ai prefix system_prefix (str): The system prefix @@ -291,6 +314,8 @@ def _messages_to_str( role = system_prefix elif isinstance(message, ViewMessage): pass + elif isinstance(message, ModelMessage): + role = message.role else: raise ValueError(f"Got unsupported message type: {message}") if role: diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py index 72023c36f..2dd12f4ed 100755 --- a/dbgpt/core/interface/tests/test_message.py +++ b/dbgpt/core/interface/tests/test_message.py @@ -421,13 +421,13 @@ def test_parse_model_messages_multiple_system_messages(): def test_to_openai_messages( human_model_message, ai_model_message, system_model_message ): - none_messages = ModelMessage.to_openai_messages([]) + none_messages = ModelMessage.to_common_messages([]) assert none_messages == [] - single_messages = ModelMessage.to_openai_messages([human_model_message]) + single_messages = ModelMessage.to_common_messages([human_model_message]) assert single_messages == [{"role": "user", "content": human_model_message.content}] - normal_messages = ModelMessage.to_openai_messages( + normal_messages = ModelMessage.to_common_messages( [ system_model_message, human_model_message, @@ -446,7 +446,7 @@ def test_to_openai_messages( def test_to_openai_messages_convert_to_compatible_format( human_model_message, ai_model_message, system_model_message ): - shuffle_messages = ModelMessage.to_openai_messages( + shuffle_messages = ModelMessage.to_common_messages( [ system_model_message, human_model_message, diff --git a/dbgpt/model/__init__.py 
b/dbgpt/model/__init__.py index 28054fe9d..dacfc9a50 100644 --- a/dbgpt/model/__init__.py +++ b/dbgpt/model/__init__.py @@ -1,5 +1,7 @@ from dbgpt.model.cluster.client import DefaultLLMClient -from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient + +# from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient +from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient __ALL__ = [ "DefaultLLMClient", diff --git a/dbgpt/model/adapter/base.py b/dbgpt/model/adapter/base.py index 793f63c8a..06089b1a4 100644 --- a/dbgpt/model/adapter/base.py +++ b/dbgpt/model/adapter/base.py @@ -152,6 +152,17 @@ def get_default_message_separator(self) -> str: except Exception: return "\n" + def get_prompt_roles(self) -> List[str]: + """Get the roles of the prompt + + Returns: + List[str]: The roles of the prompt + """ + roles = [ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI] + if self.support_system_message: + roles.append(ModelMessageRoleType.SYSTEM) + return roles + def transform_model_messages( self, messages: List[ModelMessage], convert_to_compatible_format: bool = False ) -> List[Dict[str, str]]: @@ -185,7 +196,7 @@ def transform_model_messages( # We will not do any transform in the future return self._transform_to_no_system_messages(messages) else: - return ModelMessage.to_openai_messages( + return ModelMessage.to_common_messages( messages, convert_to_compatible_format=convert_to_compatible_format ) @@ -216,7 +227,7 @@ def _transform_to_no_system_messages( Returns: List[Dict[str, str]]: The transformed model messages """ - openai_messages = ModelMessage.to_openai_messages(messages) + openai_messages = ModelMessage.to_common_messages(messages) system_messages = [] return_messages = [] for message in openai_messages: @@ -394,6 +405,9 @@ def _set_conv_converted_messages( conv.set_system_message("".join(can_use_systems)) return conv + def apply_conv_template(self) -> bool: + return self.model_type() != ModelType.PROXY + def model_adaptation( self, params: Dict, @@ -414,7 +428,11 @@ def model_adaptation( params["convert_to_compatible_format"] = convert_to_compatible_format # Some model context to dbgpt server - model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False} + model_context = { + "prompt_echo_len_char": -1, + "has_format_prompt": False, + "echo": params.get("echo", True), + } if messages: # Dict message to ModelMessage messages = [ @@ -423,6 +441,10 @@ def model_adaptation( ] params["messages"] = messages + if not self.apply_conv_template(): + # No need to apply conversation template, now for proxy LLM + return params, model_context + new_prompt = self.get_str_prompt( params, messages, tokenizer, prompt_template, convert_to_compatible_format ) @@ -442,7 +464,6 @@ def model_adaptation( # TODO remote bos token and eos token from tokenizer_config.json of model prompt_echo_len_char = len(new_prompt.replace("", "").replace("", "")) model_context["prompt_echo_len_char"] = prompt_echo_len_char - model_context["echo"] = params.get("echo", True) model_context["has_format_prompt"] = True params["prompt"] = new_prompt diff --git a/dbgpt/model/adapter/model_adapter.py b/dbgpt/model/adapter/model_adapter.py index b4d06a432..3bdebe946 100644 --- a/dbgpt/model/adapter/model_adapter.py +++ b/dbgpt/model/adapter/model_adapter.py @@ -19,7 +19,7 @@ _OLD_MODELS = [ "llama-cpp", - "proxyllm", + # "proxyllm", "gptj-6b", "codellama-13b-sql-sft", "codellama-7b", @@ -45,6 +45,7 @@ def get_llm_model_adapter( # Import NewHFChatModelAdapter for it can be registered from dbgpt.model.adapter.hf_adapter 
import NewHFChatModelAdapter + from dbgpt.model.adapter.proxy_adapter import ProxyLLMModelAdapter new_model_adapter = get_model_adapter( model_type, model_name, model_path, conv_factory diff --git a/dbgpt/model/adapter/proxy_adapter.py b/dbgpt/model/adapter/proxy_adapter.py new file mode 100644 index 000000000..6be555c0b --- /dev/null +++ b/dbgpt/model/adapter/proxy_adapter.py @@ -0,0 +1,238 @@ +import dataclasses +import logging +from abc import abstractmethod +from typing import TYPE_CHECKING, List, Optional, Type, Union + +from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter +from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory +from dbgpt.model.base import ModelType +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient +from dbgpt.model.proxy.llms.proxy_model import ProxyModel + +logger = logging.getLogger(__name__) + + +class ProxyLLMModelAdapter(LLMModelAdapter): + def new_adapter(self, **kwargs) -> "LLMModelAdapter": + return self.__class__() + + def model_type(self) -> str: + return ModelType.PROXY + + def match( + self, + model_type: str, + model_name: Optional[str] = None, + model_path: Optional[str] = None, + ) -> bool: + model_name = model_name.lower() if model_name else None + model_path = model_path.lower() if model_path else None + return self.do_match(model_name) or self.do_match(model_path) + + @abstractmethod + def do_match(self, lower_model_name_or_path: Optional[str] = None): + raise NotImplementedError() + + def dynamic_llm_client_class( + self, params: ProxyModelParameters + ) -> Optional[Type[ProxyLLMClient]]: + """Get dynamic llm client class + + Parse the llm_client_class from params and return the class + + Args: + params (ProxyModelParameters): proxy model parameters + + Returns: + Optional[Type[ProxyLLMClient]]: llm client class + """ + + if params.llm_client_class: + from dbgpt.util.module_utils import import_from_checked_string + + worker_cls: Type[ProxyLLMClient] = import_from_checked_string( + params.llm_client_class, ProxyLLMClient + ) + return worker_cls + return None + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + """Get llm client class""" + dynamic_llm_client_class = self.dynamic_llm_client_class(params) + if dynamic_llm_client_class: + return dynamic_llm_client_class + raise NotImplementedError() + + def load_from_params(self, params: ProxyModelParameters): + dynamic_llm_client_class = self.dynamic_llm_client_class(params) + if not dynamic_llm_client_class: + dynamic_llm_client_class = self.get_llm_client_class(params) + logger.info( + f"Load model from params: {params}, llm client class: {dynamic_llm_client_class}" + ) + proxy_llm_client = dynamic_llm_client_class.new_client(params) + model = ProxyModel(params, proxy_llm_client) + return model, model + + +class OpenAIProxyLLMModelAdapter(ProxyLLMModelAdapter): + def support_async(self) -> bool: + return True + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path in ["chatgpt_proxyllm", "proxyllm"] + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + """Get llm client class""" + from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient + + return OpenAILLMClient + + def get_async_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream + + return chatgpt_generate_stream + + +class 
TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter): + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path == "tongyi_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient + + return TongyiLLMClient + + def get_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream + + return tongyi_generate_stream + + +class ZhipuProxyLLMModelAdapter(ProxyLLMModelAdapter): + support_system_message = False + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path == "zhipu_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient + + return ZhipuLLMClient + + def get_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream + + return zhipu_generate_stream + + +class WenxinProxyLLMModelAdapter(ProxyLLMModelAdapter): + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path == "wenxin_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient + + return WenxinLLMClient + + def get_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream + + return wenxin_generate_stream + + +class GeminiProxyLLMModelAdapter(ProxyLLMModelAdapter): + support_system_message = False + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path == "gemini_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + from dbgpt.model.proxy.llms.gemini import GeminiLLMClient + + return GeminiLLMClient + + def get_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.gemini import gemini_generate_stream + + return gemini_generate_stream + + +class SparkProxyLLMModelAdapter(ProxyLLMModelAdapter): + support_system_message = False + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path == "spark_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + from dbgpt.model.proxy.llms.spark import SparkLLMClient + + return SparkLLMClient + + def get_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.spark import spark_generate_stream + + return spark_generate_stream + + +class BardProxyLLMModelAdapter(ProxyLLMModelAdapter): + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path == "bard_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + """Get llm client class""" + # TODO: Bard proxy LLM not support ProxyLLMClient now, we just return OpenAILLMClient + from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient + + return OpenAILLMClient + + def get_async_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.bard import bard_generate_stream + + return bard_generate_stream + + +class BaichuanProxyLLMModelAdapter(ProxyLLMModelAdapter): + def do_match(self, lower_model_name_or_path: Optional[str] = None): + 
return lower_model_name_or_path == "bc_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + """Get llm client class""" + # TODO: Baichuan proxy LLM not support ProxyLLMClient now, we just return OpenAILLMClient + from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient + + return OpenAILLMClient + + def get_async_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream + + return baichuan_generate_stream + + +register_model_adapter(OpenAIProxyLLMModelAdapter) +register_model_adapter(TongyiProxyLLMModelAdapter) +register_model_adapter(ZhipuProxyLLMModelAdapter) +register_model_adapter(WenxinProxyLLMModelAdapter) +register_model_adapter(GeminiProxyLLMModelAdapter) +register_model_adapter(SparkProxyLLMModelAdapter) +register_model_adapter(BardProxyLLMModelAdapter) +register_model_adapter(BaichuanProxyLLMModelAdapter) diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py index 426188312..bb2ea89e2 100644 --- a/dbgpt/model/cluster/client.py +++ b/dbgpt/model/cluster/client.py @@ -2,6 +2,7 @@ from typing import AsyncIterator, List, Optional from dbgpt.core.interface.llm import ( + DefaultMessageConverter, LLMClient, MessageConverter, ModelMetadata, @@ -13,14 +14,28 @@ class DefaultLLMClient(LLMClient): - def __init__(self, worker_manager: WorkerManager): + """Default LLM client implementation. + + Connect to the worker manager and send the request to the worker manager. + + Args: + worker_manager (WorkerManager): worker manager instance. + auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False. + """ + + def __init__( + self, worker_manager: WorkerManager, auto_convert_message: bool = False + ): self._worker_manager = worker_manager + self._auto_covert_message = auto_convert_message async def generate( self, request: ModelRequest, message_converter: Optional[MessageConverter] = None, ) -> ModelOutput: + if not message_converter and self._auto_covert_message: + message_converter = DefaultMessageConverter() request = await self.covert_message(request, message_converter) return await self._worker_manager.generate(request.to_dict()) @@ -29,6 +44,8 @@ async def generate_stream( request: ModelRequest, message_converter: Optional[MessageConverter] = None, ) -> AsyncIterator[ModelOutput]: + if not message_converter and self._auto_covert_message: + message_converter = DefaultMessageConverter() request = await self.covert_message(request, message_converter) async for output in self._worker_manager.generate_stream(request.to_dict()): yield output diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index c1fe109ea..a28c1ed9a 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -201,7 +201,8 @@ async def async_count_token(self, prompt: str) -> int: def get_model_metadata(self, params: Dict) -> ModelMetadata: ext_metadata = ModelExtraMedata( - prompt_sep=self.llm_adapter.get_default_message_separator() + prompt_roles=self.llm_adapter.get_prompt_roles(), + prompt_sep=self.llm_adapter.get_default_message_separator(), ) return ModelMetadata( model=self.model_name, @@ -332,19 +333,25 @@ def _handle_output( ): finish_reason = None usage = None + error_code = 0 if isinstance(output, dict): finish_reason = output.get("finish_reason") usage = output.get("usage") output = output["text"] if finish_reason is not None: 
logger.info(f"finish_reason: {finish_reason}") + elif isinstance(output, ModelOutput): + finish_reason = output.finish_reason + usage = output.usage + error_code = output.error_code + output = output.text incremental_output = output[len(previous_response) :] print(incremental_output, end="", flush=True) metrics = _new_metrics_from_model_output(last_metrics, is_first_generate, usage) model_output = ModelOutput( text=output, - error_code=0, + error_code=error_code, model_context=model_context, finish_reason=finish_reason, usage=usage, diff --git a/dbgpt/model/llm_out/proxy_llm.py b/dbgpt/model/llm_out/proxy_llm.py index cbe9c42fc..4caa1b6a5 100644 --- a/dbgpt/model/llm_out/proxy_llm.py +++ b/dbgpt/model/llm_out/proxy_llm.py @@ -13,28 +13,28 @@ from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream - -def proxyllm_generate_stream( - model: ProxyModel, tokenizer, params, device, context_len=2048 -): - generator_mapping = { - "proxyllm": chatgpt_generate_stream, - "chatgpt_proxyllm": chatgpt_generate_stream, - "bard_proxyllm": bard_generate_stream, - "claude_proxyllm": claude_generate_stream, - # "gpt4_proxyllm": gpt4_generate_stream, move to chatgpt_generate_stream - "wenxin_proxyllm": wenxin_generate_stream, - "tongyi_proxyllm": tongyi_generate_stream, - "zhipu_proxyllm": zhipu_generate_stream, - "gemini_proxyllm": gemini_generate_stream, - "bc_proxyllm": baichuan_generate_stream, - "spark_proxyllm": spark_generate_stream, - } - model_params = model.get_params() - model_name = model_params.model_name - default_error_message = f"{model_name} LLM is not supported" - generator_function = generator_mapping.get( - model_name, lambda *args: [default_error_message] - ) - - yield from generator_function(model, tokenizer, params, device, context_len) +# This has been moved to dbgpt/model/adapter/proxy_adapter.py +# def proxyllm_generate_stream( +# model: ProxyModel, tokenizer, params, device, context_len=2048 +# ): +# generator_mapping = { +# "proxyllm": chatgpt_generate_stream, +# "chatgpt_proxyllm": chatgpt_generate_stream, +# "bard_proxyllm": bard_generate_stream, +# "claude_proxyllm": claude_generate_stream, +# # "gpt4_proxyllm": gpt4_generate_stream, move to chatgpt_generate_stream +# "wenxin_proxyllm": wenxin_generate_stream, +# "tongyi_proxyllm": tongyi_generate_stream, +# "zhipu_proxyllm": zhipu_generate_stream, +# "gemini_proxyllm": gemini_generate_stream, +# "bc_proxyllm": baichuan_generate_stream, +# "spark_proxyllm": spark_generate_stream, +# } +# model_params = model.get_params() +# model_name = model_params.model_name +# default_error_message = f"{model_name} LLM is not supported" +# generator_function = generator_mapping.get( +# model_name, lambda *args: [default_error_message] +# ) +# +# yield from generator_function(model, tokenizer, params, device, context_len) diff --git a/dbgpt/model/loader.py b/dbgpt/model/loader.py index 43d4e27f8..94872a2cb 100644 --- a/dbgpt/model/loader.py +++ b/dbgpt/model/loader.py @@ -126,7 +126,8 @@ def loader_with_params( elif model_type == ModelType.LLAMA_CPP: return llamacpp_loader(llm_adapter, model_params) elif model_type == ModelType.PROXY: - return proxyllm_loader(llm_adapter, model_params) + # return proxyllm_loader(llm_adapter, model_params) + return llm_adapter.load_from_params(model_params) elif model_type == ModelType.VLLM: return llm_adapter.load_from_params(model_params) else: diff --git a/dbgpt/model/operator/llm_operator.py b/dbgpt/model/operator/llm_operator.py index 
c1d6ef068..b35d042df 100644 --- a/dbgpt/model/operator/llm_operator.py +++ b/dbgpt/model/operator/llm_operator.py @@ -37,7 +37,7 @@ def llm_client(self) -> LLMClient: self._llm_client = DefaultLLMClient(worker_manager_factory.create()) else: if self._default_llm_client is None: - from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient + from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient self._default_llm_client = OpenAILLMClient() logger.info( diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py index bcaa46c4d..6cd666a0b 100644 --- a/dbgpt/model/parameter.py +++ b/dbgpt/model/parameter.py @@ -475,6 +475,16 @@ class ProxyModelParameters(BaseModelParameters): max_context_size: Optional[int] = field( default=4096, metadata={"help": "Maximum context size"} ) + llm_client_class: Optional[str] = field( + default=None, + metadata={ + "help": "The class name of llm client, such as dbgpt.model.proxy.llms.proxy_model.ProxyModel" + }, + ) + + def __post_init__(self): + if not self.proxy_server_url and self.proxy_api_base: + self.proxy_server_url = f"{self.proxy_api_base}/chat/completions" @dataclass diff --git a/dbgpt/model/proxy/__init__.py b/dbgpt/model/proxy/__init__.py index e69de29bb..143cc76a5 100644 --- a/dbgpt/model/proxy/__init__.py +++ b/dbgpt/model/proxy/__init__.py @@ -0,0 +1,15 @@ +from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient +from dbgpt.model.proxy.llms.gemini import GeminiLLMClient +from dbgpt.model.proxy.llms.spark import SparkLLMClient +from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient +from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient +from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient + +__ALL__ = [ + "OpenAILLMClient", + "GeminiLLMClient", + "TongyiLLMClient", + "ZhipuLLMClient", + "WenxinLLMClient", + "SparkLLMClient", +] diff --git a/dbgpt/model/proxy/base.py b/dbgpt/model/proxy/base.py new file mode 100644 index 000000000..0faec81bd --- /dev/null +++ b/dbgpt/model/proxy/base.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from concurrent.futures import Executor, ThreadPoolExecutor +from functools import cache +from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional + +from dbgpt.core import ( + LLMClient, + MessageConverter, + ModelMetadata, + ModelOutput, + ModelRequest, +) +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.util.executor_utils import blocking_func_to_async + +if TYPE_CHECKING: + from tiktoken import Encoding + +logger = logging.getLogger(__name__) + + +class ProxyTokenizer(ABC): + @abstractmethod + def count_token(self, model_name: str, prompts: List[str]) -> List[int]: + """Count token of given prompts. 
+ Args: + model_name (str): model name + prompts (List[str]): prompts to count token + + Returns: + List[int]: token count, -1 if failed + """ + + +class TiktokenProxyTokenizer(ProxyTokenizer): + def __init__(self): + self._cache = {} + + def count_token(self, model_name: str, prompts: List[str]) -> List[int]: + encoding_model = self._get_or_create_encoding_model(model_name) + if not encoding_model: + return [-1] * len(prompts) + return [ + len(encoding_model.encode(prompt, disallowed_special=())) + for prompt in prompts + ] + + def _get_or_create_encoding_model(self, model_name: str) -> Optional[Encoding]: + if model_name in self._cache: + return self._cache[model_name] + encoding_model = None + try: + import tiktoken + + logger.info( + "tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, " + "also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR" + ) + except ImportError: + self._support_encoding = False + logger.warn("tiktoken not installed, cannot count tokens") + return None + try: + if not model_name: + model_name = "gpt-3.5-turbo" + encoding_model = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning( + f"{model_name}'s tokenizer not found, using cl100k_base encoding." + ) + if encoding_model: + self._cache[model_name] = encoding_model + return encoding_model + + +class ProxyLLMClient(LLMClient): + """Proxy LLM client base class""" + + executor: Executor + model_names: List[str] + + def __init__( + self, + model_names: List[str], + context_length: int = 4096, + executor: Optional[Executor] = None, + proxy_tokenizer: Optional[ProxyTokenizer] = None, + ): + self.model_names = model_names + self.context_length = context_length + self.executor = executor or ThreadPoolExecutor() + self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer() + + @classmethod + @abstractmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> "ProxyLLMClient": + """Create a new client instance from model parameters. + + Args: + model_params (ProxyModelParameters): model parameters + default_executor (Executor): default executor, If your model is blocking, + you should pass a ThreadPoolExecutor. + """ + + async def generate( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelOutput: + """Generate model output from model request. + + We strongly recommend you to implement this method instead of sync_generate for high performance. + + Args: + request (ModelRequest): model request + message_converter (Optional[MessageConverter], optional): message converter. Defaults to None. + + Returns: + ModelOutput: model output + """ + return await blocking_func_to_async( + self.executor, self.sync_generate, request, message_converter + ) + + def sync_generate( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelOutput: + """Generate model output from model request. + + Args: + request (ModelRequest): model request + message_converter (Optional[MessageConverter], optional): message converter. Defaults to None. 
+ + Returns: + ModelOutput: model output + """ + output = None + for out in self.sync_generate_stream(request, message_converter): + output = out + return output + + async def generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> AsyncIterator[ModelOutput]: + """Generate model output stream from model request. + + We strongly recommend you to implement this method instead of sync_generate_stream for high performance. + + Args: + request (ModelRequest): model request + message_converter (Optional[MessageConverter], optional): message converter. Defaults to None. + + Returns: + AsyncIterator[ModelOutput]: model output stream + """ + from starlette.concurrency import iterate_in_threadpool + + async for output in iterate_in_threadpool( + self.sync_generate_stream(request, message_converter) + ): + yield output + + def sync_generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> Iterator[ModelOutput]: + """Generate model output stream from model request. + + Args: + request (ModelRequest): model request + message_converter (Optional[MessageConverter], optional): message converter. Defaults to None. + + Returns: + Iterator[ModelOutput]: model output stream + """ + + raise NotImplementedError() + + async def models(self) -> List[ModelMetadata]: + """Get model metadata list + + Returns: + List[ModelMetadata]: model metadata list + """ + return self._models() + + @cache + def _models(self) -> List[ModelMetadata]: + results = [] + for model in self.model_names: + results.append( + ModelMetadata(model=model, context_length=self.context_length) + ) + return results + + def local_covert_message( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelRequest: + """Convert message locally + + Args: + request (ModelRequest): model request + message_converter (Optional[MessageConverter], optional): message converter. Defaults to None. 
+ + Returns: + ModelRequest: converted model request + """ + if not message_converter: + return request + metadata = self._models[0].ext_metadata + new_request = request.copy() + new_messages = message_converter.convert(request.messages, metadata) + new_request.messages = new_messages + return new_request + + async def count_token(self, model: str, prompt: str) -> int: + """Count token of given prompt + + Args: + model (str): model name + prompt (str): prompt to count token + + Returns: + int: token count, -1 if failed + """ + return await blocking_func_to_async( + self.executor, self.proxy_tokenizer.count_token, model, [prompt] + )[0] diff --git a/dbgpt/model/proxy/llms/baichuan.py b/dbgpt/model/proxy/llms/baichuan.py index aed91e151..696d838c5 100644 --- a/dbgpt/model/proxy/llms/baichuan.py +++ b/dbgpt/model/proxy/llms/baichuan.py @@ -13,6 +13,7 @@ def baichuan_generate_stream( model: ProxyModel, tokenizer=None, params=None, device=None, context_len=4096 ): + # TODO: Support new Baichuan ProxyLLMClient url = "https://api.baichuan-ai.com/v1/chat/completions" model_params = model.get_params() diff --git a/dbgpt/model/proxy/llms/bard.py b/dbgpt/model/proxy/llms/bard.py index 6317cad6d..e60091516 100755 --- a/dbgpt/model/proxy/llms/bard.py +++ b/dbgpt/model/proxy/llms/bard.py @@ -9,6 +9,7 @@ def bard_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): + # TODO: Support new bard ProxyLLMClient model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py index 15932bfcc..0d200953d 100755 --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -1,264 +1,253 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- +from __future__ import annotations import importlib.metadata as metadata import logging -import os -from typing import List - -import httpx - -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from concurrent.futures import Executor +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union + +from dbgpt.core import ( + MessageConverter, + ModelMetadata, + ModelOutput, + ModelRequest, + ModelRequestContext, +) from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.utils.chatgpt_utils import OpenAIParameters -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from httpx._types import ProxiesTypes + from openai import AsyncAzureOpenAI, AsyncOpenAI + ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI] -def _initialize_openai(params: ProxyModelParameters): - try: - import openai - except ImportError as exc: - raise ValueError( - "Could not import python package: openai " - "Please install openai by command `pip install openai` " - ) from exc +logger = logging.getLogger(__name__) - api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai") - api_base = params.proxy_api_base or os.getenv( - "OPENAI_API_TYPE", - os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None, - ) - api_key = params.proxy_api_key or os.getenv( - "OPENAI_API_KEY", - os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None, +async def chatgpt_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=2048 +): + client: OpenAILLMClient = model.proxy_llm_client + context = ModelRequestContext(stream=True, 
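# Minimal sketch, not part of this patch, of a concrete subclass plugging into the
# ProxyLLMClient base class defined above. It only relies on hooks shown in this
# diff; the EchoLLMClient name and its echo behaviour are illustrative assumptions.
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient


class EchoLLMClient(ProxyLLMClient):
    """Toy client that streams the last message back, for illustration only."""

    def __init__(self, model_alias: str = "echo_proxyllm", context_length: int = 4096):
        super().__init__(model_names=[model_alias], context_length=context_length)

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "EchoLLMClient":
        # Map the generic proxy parameters onto this client.
        return cls(
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
        )

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        # The async generate()/generate_stream() defaults of the base class run
        # this blocking generator in the configured executor.
        request = self.local_covert_message(request, message_converter)
        text = ""
        for message in request.messages:
            text = message.content
        yield ModelOutput(text=text, error_code=0)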
user_name=params.get("user_name")) + request = ModelRequest.build_request( + client.default_model, + messages=params["messages"], + temperature=params.get("temperature"), + context=context, + max_new_tokens=params.get("max_new_tokens"), ) - api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION") - - if not api_base and params.proxy_server_url: - # Adapt previous proxy_server_url configuration - api_base = params.proxy_server_url.split("/chat/completions")[0] - if api_type: - openai.api_type = api_type - if api_base: - openai.api_base = api_base - if api_key: - openai.api_key = api_key - if api_version: - openai.api_version = api_version - if params.http_proxy: - openai.proxy = params.http_proxy - - openai_params = { - "api_type": api_type, - "api_base": api_base, - "api_version": api_version, - "proxy": params.http_proxy, - } - - return openai_params - - -def _initialize_openai_v1(params: ProxyModelParameters): - try: - from openai import OpenAI - except ImportError as exc: - raise ValueError( - "Could not import python package: openai " - "Please install openai by command `pip install openai" + async for r in client.generate_stream(request): + yield r + + +class OpenAILLMClient(ProxyLLMClient): + def __init__( + self, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_type: Optional[str] = None, + api_version: Optional[str] = None, + model: Optional[str] = None, + proxies: Optional["ProxiesTypes"] = None, + timeout: Optional[int] = 240, + model_alias: Optional[str] = "chatgpt_proxyllm", + context_length: Optional[int] = 8192, + openai_client: Optional["ClientType"] = None, + openai_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + try: + import openai + except ImportError as exc: + raise ValueError( + "Could not import python package: openai " + "Please install openai by command `pip install openai" + ) from exc + self._openai_version = metadata.version("openai") + self._openai_less_then_v1 = not self._openai_version >= "1.0.0" + self._init_params = OpenAIParameters( + api_type=api_type, + api_base=api_base, + api_key=api_key, + api_version=api_version, + proxies=proxies, + full_url=kwargs.get("full_url"), ) - api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai") - - base_url = params.proxy_api_base or os.getenv( - "OPENAI_API_BASE", - os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None, - ) - api_key = params.proxy_api_key or os.getenv( - "OPENAI_API_KEY", - os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None, - ) - api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION") - - if not base_url and params.proxy_server_url: - # Adapt previous proxy_server_url configuration - base_url = params.proxy_server_url.split("/chat/completions")[0] - - proxies = params.http_proxy - openai_params = { - "api_key": api_key, - "base_url": base_url, - } - return openai_params, api_type, api_version, proxies - - -def __convert_2_gpt_messages(messages: List[ModelMessage]): - gpt_messages = [] - last_usr_message = "" - system_messages = [] + self._model = model + self._proxies = proxies + self._timeout = timeout + self._model_alias = model_alias + self._context_length = context_length + self._api_type = api_type + self._client = openai_client + self._openai_kwargs = openai_kwargs or {} + super().__init__(model_names=[model_alias], context_length=context_length) + + @classmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> 
"OpenAILLMClient": + return cls( + api_key=model_params.proxy_api_key, + api_base=model_params.proxy_api_base, + api_type=model_params.proxy_api_type, + api_version=model_params.proxy_api_version, + model=model_params.proxyllm_backend, + proxies=model_params.http_proxy, + model_alias=model_params.model_name, + context_length=max(model_params.max_context_size, 8192), + full_url=model_params.proxy_server_url, + ) - # TODO: We can't change message order in low level - for message in messages: - if message.role == ModelMessageRoleType.HUMAN or message.role == "user": - last_usr_message = message.content - elif message.role == ModelMessageRoleType.SYSTEM: - system_messages.append(message.content) - elif message.role == ModelMessageRoleType.AI or message.role == "assistant": - last_ai_message = message.content - gpt_messages.append({"role": "user", "content": last_usr_message}) - gpt_messages.append({"role": "assistant", "content": last_ai_message}) + @property + def client(self) -> ClientType: + if self._openai_less_then_v1: + raise ValueError( + "Current model (Load by OpenAILLMClient) require openai.__version__>=1.0.0" + ) + if self._client is None: + from dbgpt.model.utils.chatgpt_utils import _build_openai_client - if len(system_messages) > 0: - if len(system_messages) < 2: - gpt_messages.insert(0, {"role": "system", "content": system_messages[0]}) - gpt_messages.append({"role": "user", "content": last_usr_message}) + self._api_type, self._client = _build_openai_client( + init_params=self._init_params + ) + return self._client + + @property + def default_model(self) -> str: + model = self._model + if not model: + model = "gpt-35-turbo" if self._api_type == "azure" else "gpt-3.5-turbo" + return model + + def _build_request( + self, request: ModelRequest, stream: Optional[bool] = False + ) -> Dict[str, Any]: + payload = {"stream": stream} + model = request.model or self.default_model + if self._openai_less_then_v1 and self._api_type == "azure": + payload["engine"] = model else: - gpt_messages.append({"role": "user", "content": system_messages[1]}) - else: - last_message = messages[-1] - if last_message.role == ModelMessageRoleType.HUMAN: - gpt_messages.append({"role": "user", "content": last_message.content}) - - return gpt_messages - - -def _build_request(model: ProxyModel, params): - model_params = model.get_params() - logger.info(f"Model: {model}, model_params: {model_params}") - - messages: List[ModelMessage] = params["messages"] - - # history = __convert_2_gpt_messages(messages) - convert_to_compatible_format = params.get("convert_to_compatible_format", False) - history = ModelMessage.to_openai_messages( - messages, convert_to_compatible_format=convert_to_compatible_format - ) - payloads = { - "temperature": params.get("temperature"), - "max_tokens": params.get("max_new_tokens"), - "stream": True, - } - proxyllm_backend = model_params.proxyllm_backend + payload["model"] = model + # Apply openai kwargs + for k, v in self._openai_kwargs.items(): + payload[k] = v + if request.temperature: + payload["temperature"] = request.temperature + if request.max_new_tokens: + payload["max_tokens"] = request.max_new_tokens + return payload + + async def generate( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelOutput: + request = self.local_covert_message(request, message_converter) + messages = request.to_common_messages() + payload = self._build_request(request) + logger.info( + f"Send request to openai({self._openai_version}), payload: 
{payload}\n\n messages:\n{messages}" + ) + try: + if self._openai_less_then_v1: + return await self.generate_less_then_v1(messages, payload) + else: + return await self.generate_v1(messages, payload) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) - if metadata.version("openai") >= "1.0.0": - openai_params, api_type, api_version, proxies = _initialize_openai_v1( - model_params + async def generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> AsyncIterator[ModelOutput]: + request = self.local_covert_message(request, message_converter) + messages = request.to_common_messages() + payload = self._build_request(request, stream=True) + logger.info( + f"Send request to openai({self._openai_version}), payload: {payload}\n\n messages:\n{messages}" ) - proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo" - payloads["model"] = proxyllm_backend - else: - openai_params = _initialize_openai(model_params) - if openai_params["api_type"] == "azure": - # engine = "deployment_name". - proxyllm_backend = proxyllm_backend or "gpt-35-turbo" - payloads["engine"] = proxyllm_backend + if self._openai_less_then_v1: + async for r in self.generate_stream_less_then_v1(messages, payload): + yield r else: - proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo" - payloads["model"] = proxyllm_backend - - logger.info(f"Send request to real model {proxyllm_backend}") - return history, payloads - - -def chatgpt_generate_stream( - model: ProxyModel, tokenizer, params, device, context_len=2048 -): - if metadata.version("openai") >= "1.0.0": - model_params = model.get_params() - openai_params, api_type, api_version, proxies = _initialize_openai_v1( - model_params + async for r in self.generate_stream_v1(messages, payload): + yield r + + async def generate_v1( + self, messages: List[Dict[str, Any]], payload: Dict[str, Any] + ) -> ModelOutput: + chat_completion = await self.client.chat.completions.create( + messages=messages, **payload ) - history, payloads = _build_request(model, params) - if api_type == "azure": - from openai import AzureOpenAI + text = chat_completion.choices[0].message.content + usage = chat_completion.usage.dict() + return ModelOutput(text=text, error_code=0, usage=usage) - client = AzureOpenAI( - api_key=openai_params["api_key"], - api_version=api_version, - azure_endpoint=openai_params["base_url"], - http_client=httpx.Client(proxies=proxies), - ) - else: - from openai import OpenAI + async def generate_less_then_v1( + self, messages: List[Dict[str, Any]], payload: Dict[str, Any] + ) -> ModelOutput: + import openai - client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies)) - res = client.chat.completions.create(messages=history, **payloads) + chat_completion = await openai.ChatCompletion.acreate( + messages=messages, **payload + ) + text = chat_completion.choices[0].message.content + usage = chat_completion.usage.to_dict() + return ModelOutput(text=text, error_code=0, usage=usage) + + async def generate_stream_v1( + self, messages: List[Dict[str, Any]], payload: Dict[str, Any] + ) -> AsyncIterator[ModelOutput]: + chat_completion = await self.client.chat.completions.create( + messages=messages, **payload + ) text = "" - for r in res: - # logger.info(str(r)) - # Azure Openai reponse may have empty choices body in the first chunk - # to avoid index out of range error + async for r in chat_completion: if len(r.choices) == 0: continue if 
r.choices[0].delta.content is not None: content = r.choices[0].delta.content text += content - yield text + yield ModelOutput(text=text, error_code=0) - else: + async def generate_stream_less_then_v1( + self, messages: List[Dict[str, Any]], payload: Dict[str, Any] + ) -> AsyncIterator[ModelOutput]: import openai - history, payloads = _build_request(model, params) - - res = openai.ChatCompletion.create(messages=history, **payloads) - + res = await openai.ChatCompletion.acreate(messages=messages, **payload) text = "" - for r in res: - if len(r.choices) == 0: + async for r in res: + if not r.get("choices"): continue if r["choices"][0]["delta"].get("content") is not None: content = r["choices"][0]["delta"]["content"] text += content - yield text - + yield ModelOutput(text=text, error_code=0) -async def async_chatgpt_generate_stream( - model: ProxyModel, tokenizer, params, device, context_len=2048 -): - if metadata.version("openai") >= "1.0.0": - model_params = model.get_params() - openai_params, api_type, api_version, proxies = _initialize_openai_v1( - model_params + async def models(self) -> List[ModelMetadata]: + model_metadata = ModelMetadata( + model=self._model_alias, + context_length=await self.get_context_length(), ) - history, payloads = _build_request(model, params) - if api_type == "azure": - from openai import AsyncAzureOpenAI + return [model_metadata] - client = AsyncAzureOpenAI( - api_key=openai_params["api_key"], - api_version=api_version, - azure_endpoint=openai_params["base_url"], - http_client=httpx.AsyncClient(proxies=proxies), - ) - else: - from openai import AsyncOpenAI - - client = AsyncOpenAI( - **openai_params, http_client=httpx.AsyncClient(proxies=proxies) - ) + async def get_context_length(self) -> int: + """Get the context length of the model. - res = await client.chat.completions.create(messages=history, **payloads) - text = "" - for r in res: - if not r.get("choices"): - continue - if r.choices[0].delta.content is not None: - content = r.choices[0].delta.content - text += content - yield text - else: - import openai - - history, payloads = _build_request(model, params) - - res = await openai.ChatCompletion.acreate(messages=history, **payloads) - - text = "" - async for r in res: - if not r.get("choices"): - continue - if r["choices"][0]["delta"].get("content") is not None: - content = r["choices"][0]["delta"]["content"] - text += content - yield text + Returns: + int: The context length. + # TODO: This is a temporary solution. We should have a better way to get the context length. + eg. get real context length from the openai api. 
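# Illustrative usage sketch, not part of this patch: calling the OpenAILLMClient
# defined above outside the worker. It assumes openai>=1.0.0 is installed, that
# OPENAI_API_KEY is set in the environment, and that ModelMessage instances can
# be built directly as below; the question text is made up.
import asyncio

from dbgpt.core import (
    ModelMessage,
    ModelMessageRoleType,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient


async def _demo() -> None:
    client = OpenAILLMClient(model="gpt-3.5-turbo")
    request = ModelRequest.build_request(
        client.default_model,
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Say hello.")],
        temperature=0.5,
        context=ModelRequestContext(stream=False),
        max_new_tokens=16,
    )
    # Non-streaming call; generate_stream() yields incremental ModelOutput chunks.
    output = await client.generate(request)
    print(output.text)


if __name__ == "__main__":
    asyncio.run(_demo())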
+ """ + return self._context_length diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py index 9a3b3b868..87af2f1a2 100644 --- a/dbgpt/model/proxy/llms/gemini.py +++ b/dbgpt/model/proxy/llms/gemini.py @@ -1,72 +1,54 @@ -from typing import Any, Dict, List, Tuple - -from dbgpt.core.interface.message import ModelMessage, parse_model_messages +import os +from concurrent.futures import Executor +from typing import Any, Dict, Iterator, List, Optional, Tuple + +from dbgpt.core import ( + MessageConverter, + ModelMessage, + ModelOutput, + ModelRequest, + ModelRequestContext, +) +from dbgpt.core.interface.message import parse_model_messages +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.llms.proxy_model import ProxyModel GEMINI_DEFAULT_MODEL = "gemini-pro" +safety_settings = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, +] + def gemini_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): - """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview""" model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") - - # TODO proxy model use unified config? - proxy_api_key = model_params.proxy_api_key - proxyllm_backend = GEMINI_DEFAULT_MODEL or model_params.proxyllm_backend - - generation_config = { - "temperature": 0.7, - "top_p": 1, - "top_k": 1, - "max_output_tokens": 2048, - } - - safety_settings = [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}, - { - "category": "HARM_CATEGORY_HATE_SPEECH", - "threshold": "BLOCK_MEDIUM_AND_ABOVE", - }, - { - "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "threshold": "BLOCK_MEDIUM_AND_ABOVE", - }, - { - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "threshold": "BLOCK_MEDIUM_AND_ABOVE", - }, - ] - - import google.generativeai as genai - - if model_params.proxy_api_base: - from google.api_core import client_options - - client_opts = client_options.ClientOptions( - api_endpoint=model_params.proxy_api_base - ) - genai.configure( - api_key=proxy_api_key, transport="rest", client_options=client_opts - ) - else: - genai.configure(api_key=proxy_api_key) - model = genai.GenerativeModel( - model_name=proxyllm_backend, - generation_config=generation_config, - safety_settings=safety_settings, + client: GeminiLLMClient = model.proxy_llm_client + context = ModelRequestContext(stream=True, user_name=params.get("user_name")) + request = ModelRequest.build_request( + client.default_model, + messages=params["messages"], + temperature=params.get("temperature"), + context=context, + max_new_tokens=params.get("max_new_tokens"), ) - messages: List[ModelMessage] = params["messages"] - user_prompt, gemini_hist = _transform_to_gemini_messages(messages) - chat = model.start_chat(history=gemini_hist) - response = chat.send_message(user_prompt, stream=True) - text = "" - for chunk in response: - text += chunk.text - print(text) - yield text + for r in client.sync_generate_stream(request): + yield r def _transform_to_gemini_messages( @@ -97,12 +79,104 @@ def _transform_to_gemini_messages( {"role": "model", "parts": {"text": "Hi there!"}}, ] """ + # TODO raise 
error if messages has system message user_prompt, system_messages, history_messages = parse_model_messages(messages) if system_messages: - user_prompt = "".join(system_messages) + "\n" + user_prompt + raise ValueError("Gemini does not support system role") gemini_hist = [] if history_messages: for user_message, model_message in history_messages: gemini_hist.append({"role": "user", "parts": {"text": user_message}}) gemini_hist.append({"role": "model", "parts": {"text": model_message}}) return user_prompt, gemini_hist + + +class GeminiLLMClient(ProxyLLMClient): + def __init__( + self, + model: Optional[str] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + model_alias: Optional[str] = "gemini_proxyllm", + context_length: Optional[int] = 8192, + executor: Optional[Executor] = None, + ): + try: + import google.generativeai as genai + + except ImportError as exc: + raise ValueError( + "Could not import python package: generativeai " + "Please install dashscope by command `pip install google-generativeai" + ) from exc + if not model: + model = GEMINI_DEFAULT_MODEL + self._api_key = api_key if api_key else os.getenv("GEMINI_PROXY_API_KEY") + self._api_base = api_base if api_base else os.getenv("GEMINI_PROXY_API_BASE") + self._model = model + self.default_model = self._model + if not self._api_key: + raise RuntimeError("api_key can't be empty") + + if self._api_base: + from google.api_core import client_options + + client_opts = client_options.ClientOptions(api_endpoint=self._api_base) + genai.configure( + api_key=self._api_key, transport="rest", client_options=client_opts + ) + else: + genai.configure(api_key=self._api_key) + super().__init__( + model_names=[model, model_alias], + context_length=context_length, + executor=executor, + ) + + @classmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> "GeminiLLMClient": + return cls( + model=model_params.proxyllm_backend, + api_key=model_params.proxy_api_key, + api_base=model_params.proxy_api_base, + model_alias=model_params.model_name, + context_length=model_params.max_context_size, + executor=default_executor, + ) + + def sync_generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> Iterator[ModelOutput]: + request = self.local_covert_message(request, message_converter) + try: + import google.generativeai as genai + + generation_config = { + "temperature": request.temperature, + "top_p": 1, + "top_k": 1, + "max_output_tokens": request.max_new_tokens, + } + model = genai.GenerativeModel( + model_name=self._model, + generation_config=generation_config, + safety_settings=safety_settings, + ) + user_prompt, gemini_hist = _transform_to_gemini_messages(request.messages) + chat = model.start_chat(history=gemini_hist) + response = chat.send_message(user_prompt, stream=True) + text = "" + for chunk in response: + text += chunk.text + yield ModelOutput(text=text, error_code=0) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) diff --git a/dbgpt/model/proxy/llms/proxy_model.py b/dbgpt/model/proxy/llms/proxy_model.py index b287ea88f..3ee3c67fd 100644 --- a/dbgpt/model/proxy/llms/proxy_model.py +++ b/dbgpt/model/proxy/llms/proxy_model.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, List, Optional, Union from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient from 
dbgpt.model.utils.token_utils import ProxyTokenizerWrapper if TYPE_CHECKING: @@ -13,9 +14,14 @@ class ProxyModel: - def __init__(self, model_params: ProxyModelParameters) -> None: + def __init__( + self, + model_params: ProxyModelParameters, + proxy_llm_client: Optional[ProxyLLMClient] = None, + ) -> None: self._model_params = model_params self._tokenizer = ProxyTokenizerWrapper() + self.proxy_llm_client = proxy_llm_client def get_params(self) -> ProxyModelParameters: return self._model_params diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py index 81a7cc2b4..57bb9f906 100644 --- a/dbgpt/model/proxy/llms/spark.py +++ b/dbgpt/model/proxy/llms/spark.py @@ -2,15 +2,19 @@ import hashlib import hmac import json +import os +from concurrent.futures import Executor from datetime import datetime from time import mktime -from typing import List +from typing import Iterator, Optional from urllib.parse import urlencode, urlparse from wsgiref.handlers import format_date_time from websockets.sync.client import connect -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.llms.proxy_model import ProxyModel SPARK_DEFAULT_API_VERSION = "v3" @@ -34,63 +38,21 @@ def checklen(text): def spark_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): - model_params = model.get_params() - proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION - proxy_api_key = model_params.proxy_api_key - proxy_api_secret = model_params.proxy_api_secret - proxy_app_id = model_params.proxy_api_app_id - - if proxy_api_version == SPARK_DEFAULT_API_VERSION: - url = "ws://spark-api.xf-yun.com/v3.1/chat" - domain = "generalv3" - else: - url = "ws://spark-api.xf-yun.com/v2.1/chat" - domain = "generalv2" - - messages: List[ModelMessage] = params["messages"] - - last_user_input = None - for index in range(len(messages) - 1, -1, -1): - print(f"index: {index}") - if messages[index].role == ModelMessageRoleType.HUMAN: - last_user_input = {"role": "user", "content": messages[index].content} - del messages[index] - break - - # TODO: Support convert_to_compatible_format config - convert_to_compatible_format = params.get("convert_to_compatible_format", False) - - history = [] - # Add history conversation - for message in messages: - # There is no role for system in spark LLM - if message.role == ModelMessageRoleType.HUMAN or ModelMessageRoleType.SYSTEM: - history.append({"role": "user", "content": message.content}) - elif message.role == ModelMessageRoleType.AI: - history.append({"role": "assistant", "content": message.content}) - else: - pass - - question = checklen(history + [last_user_input]) - - print('last_user_input.get("content")', last_user_input.get("content")) - data = { - "header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))}, - "parameter": { - "chat": { - "domain": domain, - "random_threshold": 0.5, - "max_tokens": context_len, - "auditing": "default", - "temperature": params.get("temperature"), - } - }, - "payload": {"message": {"text": question}}, - } - - spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url) - request_url = spark_api.gen_url() - return get_response(request_url, data) + client: SparkLLMClient = model.proxy_llm_client + context = ModelRequestContext( + 
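# Sketch, not part of this patch, of the delegation pattern shared by the
# rewritten *_generate_stream adapters (gemini/spark/tongyi/wenxin/zhipu): the
# legacy params dict is turned into a ModelRequest and handed to the
# ProxyLLMClient attached to the ProxyModel. The helper name is illustrative.
from typing import Any, Dict, Iterator

from dbgpt.core import ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel


def generic_proxy_generate_stream(
    model: ProxyModel, params: Dict[str, Any]
) -> Iterator[ModelOutput]:
    client = model.proxy_llm_client  # set when the worker loads a proxy model
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    yield from client.sync_generate_stream(request)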
stream=True, + user_name=params.get("user_name"), + request_id=params.get("request_id"), + ) + request = ModelRequest.build_request( + client.default_model, + messages=params["messages"], + temperature=params.get("temperature"), + context=context, + max_new_tokens=params.get("max_new_tokens"), + ) + for r in client.sync_generate_stream(request): + yield r def get_response(request_url, data): @@ -107,8 +69,8 @@ def get_response(request_url, data): result += text[0]["content"] if choices.get("status") == 2: break - except Exception: - break + except Exception as e: + raise e yield result @@ -155,3 +117,103 @@ def gen_url(self): url = self.spark_url + "?" + urlencode(v) # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 return url + + +class SparkLLMClient(ProxyLLMClient): + def __init__( + self, + model: Optional[str] = None, + app_id: Optional[str] = None, + api_key: Optional[str] = None, + api_secret: Optional[str] = None, + api_base: Optional[str] = None, + api_domain: Optional[str] = None, + model_version: Optional[str] = None, + model_alias: Optional[str] = "spark_proxyllm", + context_length: Optional[int] = 4096, + executor: Optional[Executor] = None, + ): + if not model_version: + model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION") + if not api_base: + if model_version == SPARK_DEFAULT_API_VERSION: + api_base = "ws://spark-api.xf-yun.com/v3.1/chat" + domain = "generalv3" + else: + api_base = "ws://spark-api.xf-yun.com/v2.1/chat" + domain = "generalv2" + if not api_domain: + api_domain = domain + self._model = model + self.default_model = self._model + self._model_version = model_version + self._api_base = api_base + self._domain = api_domain + self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID") + self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET") + self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY") + + if not self._app_id: + raise ValueError("app_id can't be empty") + if not self._api_key: + raise ValueError("api_key can't be empty") + if not self._api_secret: + raise ValueError("api_secret can't be empty") + + super().__init__( + model_names=[model, model_alias], + context_length=context_length, + executor=executor, + ) + + @classmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> "SparkLLMClient": + return cls( + model=model_params.proxyllm_backend, + app_id=model_params.proxy_api_app_id, + api_key=model_params.proxy_api_key, + api_secret=model_params.proxy_api_secret, + api_base=model_params.proxy_api_base, + model_alias=model_params.model_name, + context_length=model_params.max_context_size, + executor=default_executor, + ) + + def sync_generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> Iterator[ModelOutput]: + request = self.local_covert_message(request, message_converter) + messages = request.to_common_messages(support_system_role=False) + request_id = request.context.request_id or "1" + data = { + "header": {"app_id": self._app_id, "uid": request_id}, + "parameter": { + "chat": { + "domain": self._domain, + "random_threshold": 0.5, + "max_tokens": request.max_new_tokens, + "auditing": "default", + "temperature": request.temperature, + } + }, + "payload": {"message": {"text": messages}}, + } + + spark_api = SparkAPI( + self._app_id, self._api_key, self._api_secret, self._api_base + ) + request_url = spark_api.gen_url() + try: + for text in get_response(request_url, data): + yield 
ModelOutput(text=text, error_code=0) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) diff --git a/dbgpt/model/proxy/llms/tongyi.py b/dbgpt/model/proxy/llms/tongyi.py index e5d008de1..a657ab160 100644 --- a/dbgpt/model/proxy/llms/tongyi.py +++ b/dbgpt/model/proxy/llms/tongyi.py @@ -1,79 +1,109 @@ import logging -from typing import List +from concurrent.futures import Executor +from typing import Iterator, Optional -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.llms.proxy_model import ProxyModel logger = logging.getLogger(__name__) -def __convert_2_tongyi_messages(messages: List[ModelMessage]): - chat_round = 0 - tongyi_messages = [] - - last_usr_message = "" - system_messages = [] - - for message in messages: - if message.role == ModelMessageRoleType.HUMAN: - last_usr_message = message.content - elif message.role == ModelMessageRoleType.SYSTEM: - system_messages.append(message.content) - elif message.role == ModelMessageRoleType.AI: - last_ai_message = message.content - tongyi_messages.append({"role": "user", "content": last_usr_message}) - tongyi_messages.append({"role": "assistant", "content": last_ai_message}) - if len(system_messages) > 0: - if len(system_messages) < 2: - tongyi_messages.insert(0, {"role": "system", "content": system_messages[0]}) - tongyi_messages.append({"role": "user", "content": last_usr_message}) - else: - tongyi_messages.append({"role": "user", "content": system_messages[1]}) - else: - last_message = messages[-1] - if last_message.role == ModelMessageRoleType.HUMAN: - tongyi_messages.append({"role": "user", "content": last_message.content}) - - return tongyi_messages - - def tongyi_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): - import dashscope - from dashscope import Generation + client: TongyiLLMClient = model.proxy_llm_client + context = ModelRequestContext(stream=True, user_name=params.get("user_name")) + request = ModelRequest.build_request( + client.default_model, + messages=params["messages"], + temperature=params.get("temperature"), + context=context, + max_new_tokens=params.get("max_new_tokens"), + ) + for r in client.sync_generate_stream(request): + yield r - model_params = model.get_params() - print(f"Model: {model}, model_params: {model_params}") - proxy_api_key = model_params.proxy_api_key - dashscope.api_key = proxy_api_key +class TongyiLLMClient(ProxyLLMClient): + def __init__( + self, + model: Optional[str] = None, + api_key: Optional[str] = None, + api_region: Optional[str] = None, + model_alias: Optional[str] = "tongyi_proxyllm", + context_length: Optional[int] = 4096, + executor: Optional[Executor] = None, + ): + try: + import dashscope + from dashscope import Generation + except ImportError as exc: + raise ValueError( + "Could not import python package: dashscope " + "Please install dashscope by command `pip install dashscope" + ) from exc + if not model: + model = Generation.Models.qwen_turbo + if api_key: + dashscope.api_key = api_key + if api_region: + dashscope.api_region = api_region + self._model = model + self.default_model = self._model - proxyllm_backend = model_params.proxyllm_backend - if not proxyllm_backend: - proxyllm_backend = Generation.Models.qwen_turbo # By 
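# Illustrative sketch, not part of this patch: constructing the SparkLLMClient
# defined above by hand. The credentials are placeholders (the patch also reads
# XUNFEI_SPARK_APPID / XUNFEI_SPARK_API_KEY / XUNFEI_SPARK_API_SECRET from the
# environment), and the model label is an assumption.
from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.spark import SparkLLMClient

client = SparkLLMClient(
    model="spark",                  # assumed label, used as the request model name
    app_id="your-app-id",           # placeholder credentials
    api_key="your-api-key",
    api_secret="your-api-secret",
    model_version="v3",             # selects ws://spark-api.xf-yun.com/v3.1/chat, domain generalv3
)
request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    temperature=0.7,
    context=ModelRequestContext(stream=True, request_id="1"),
    max_new_tokens=512,
)
for output in client.sync_generate_stream(request):
    print(output.text)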
Default qwen_turbo + super().__init__( + model_names=[model, model_alias], + context_length=context_length, + executor=executor, + ) - messages: List[ModelMessage] = params["messages"] - convert_to_compatible_format = params.get("convert_to_compatible_format", False) + @classmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> "TongyiLLMClient": + return cls( + model=model_params.proxyllm_backend, + api_key=model_params.proxy_api_key, + model_alias=model_params.model_name, + context_length=model_params.max_context_size, + executor=default_executor, + ) - if convert_to_compatible_format: - history = __convert_2_tongyi_messages(messages) - else: - history = ModelMessage.to_openai_messages(messages) - gen = Generation() - res = gen.call( - proxyllm_backend, - messages=history, - top_p=params.get("top_p", 0.8), - stream=True, - result_format="message", - ) + def sync_generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> Iterator[ModelOutput]: + from dashscope import Generation + + request = self.local_covert_message(request, message_converter) + + messages = request.to_common_messages() - for r in res: - if r: - if r["status_code"] == 200: - content = r["output"]["choices"][0]["message"].get("content") - yield content - else: - content = r["code"] + ":" + r["message"] - yield content + model = request.model or self._model + try: + gen = Generation() + res = gen.call( + model, + messages=messages, + top_p=0.8, + stream=True, + result_format="message", + ) + for r in res: + if r: + if r["status_code"] == 200: + content = r["output"]["choices"][0]["message"].get("content") + yield ModelOutput(text=content, error_code=0) + else: + content = r["code"] + ":" + r["message"] + yield ModelOutput(text=content, error_code=-1) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) diff --git a/dbgpt/model/proxy/llms/wenxin.py b/dbgpt/model/proxy/llms/wenxin.py index 74e44fa80..0197b8fa6 100644 --- a/dbgpt/model/proxy/llms/wenxin.py +++ b/dbgpt/model/proxy/llms/wenxin.py @@ -1,12 +1,36 @@ import json -from typing import List +import logging +import os +from concurrent.futures import Executor +from typing import Iterator, List, Optional import requests from cachetools import TTLCache, cached -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.core import ( + MessageConverter, + ModelMessage, + ModelMessageRoleType, + ModelOutput, + ModelRequest, + ModelRequestContext, +) +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.llms.proxy_model import ProxyModel +# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t +MODEL_VERSION_MAPPING = { + "ERNIE-Bot-4.0": "completions_pro", + "ERNIE-Bot-8K": "ernie_bot_8k", + "ERNIE-Bot": "completions", + "ERNIE-Bot-turbo": "eb-instant", +} + +_DEFAULT_MODEL = "ERNIE-Bot" + +logger = logging.getLogger(__name__) + @cached(TTLCache(1, 1800)) def _build_access_token(api_key: str, secret_key: str) -> str: @@ -49,94 +73,128 @@ def _to_wenxin_messages(messages: List[ModelMessage]): return wenxin_messages, str_system_message -def __convert_2_wenxin_messages(messages: List[ModelMessage]): - wenxin_messages = [] - - last_usr_message = "" - system_messages = [] - - for message in messages: - if message.role == ModelMessageRoleType.HUMAN: - last_usr_message = 
message.content - elif message.role == ModelMessageRoleType.SYSTEM: - system_messages.append(message.content) - elif message.role == ModelMessageRoleType.AI: - last_ai_message = message.content - wenxin_messages.append({"role": "user", "content": last_usr_message}) - wenxin_messages.append({"role": "assistant", "content": last_ai_message}) - - # build last user messge - - if len(system_messages) > 0: - if len(system_messages) > 1: - end_message = system_messages[-1] - else: - last_message = messages[-1] - if last_message.role == ModelMessageRoleType.HUMAN: - end_message = system_messages[-1] + "\n" + last_message.content - else: - end_message = system_messages[-1] - else: - last_message = messages[-1] - end_message = last_message.content - wenxin_messages.append({"role": "user", "content": end_message}) - str_system_message = system_messages[0] if len(system_messages) > 0 else "" - return wenxin_messages, str_system_message - - def wenxin_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): - MODEL_VERSION = { - "ERNIE-Bot": "completions", - "ERNIE-Bot-turbo": "eb-instant", - } - - model_params = model.get_params() - model_name = model_params.proxyllm_backend - model_version = MODEL_VERSION.get(model_name) - if not model_version: - yield f"Unsupport model version {model_name}" - - proxy_api_key = model_params.proxy_api_key - proxy_api_secret = model_params.proxy_api_secret - access_token = _build_access_token(proxy_api_key, proxy_api_secret) - - headers = {"Content-Type": "application/json", "Accept": "application/json"} - - proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}" - - if not access_token: - yield "Failed to get access token. please set the correct api_key and secret key." 
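# Sketch, not part of this patch, showing how the MODEL_VERSION_MAPPING defined
# above resolves an ERNIE model name to the path segment of the wenxin chat
# endpoint used later in this file. The access token value is a placeholder.
MODEL_VERSION_MAPPING = {
    "ERNIE-Bot-4.0": "completions_pro",
    "ERNIE-Bot-8K": "ernie_bot_8k",
    "ERNIE-Bot": "completions",
    "ERNIE-Bot-turbo": "eb-instant",
}


def wenxin_chat_url(model_name: str, access_token: str) -> str:
    model_version = MODEL_VERSION_MAPPING.get(model_name)
    if not model_version:
        raise ValueError(f"Unsupported wenxin model: {model_name}")
    return (
        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
        f"{model_version}?access_token={access_token}"
    )


print(wenxin_chat_url("ERNIE-Bot", "<access-token-placeholder>"))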
- - messages: List[ModelMessage] = params["messages"] - - convert_to_compatible_format = params.get("convert_to_compatible_format", False) - if convert_to_compatible_format: - history, system_message = __convert_2_wenxin_messages(messages) - else: - history, system_message = _to_wenxin_messages(messages) - payload = { - "messages": history, - "system": system_message, - "temperature": params.get("temperature"), - "stream": True, - } - - text = "" - res = requests.post(proxy_server_url, headers=headers, json=payload, stream=True) - print(f"Send request to {proxy_server_url} with real model {model_name}") - for line in res.iter_lines(): - if line: - if not line.startswith(b"data: "): - error_message = line.decode("utf-8") - yield error_message + client: WenxinLLMClient = model.proxy_llm_client + context = ModelRequestContext(stream=True, user_name=params.get("user_name")) + request = ModelRequest.build_request( + client.default_model, + messages=params["messages"], + temperature=params.get("temperature"), + context=context, + max_new_tokens=params.get("max_new_tokens"), + ) + for r in client.sync_generate_stream(request): + yield r + + +class WenxinLLMClient(ProxyLLMClient): + def __init__( + self, + model: Optional[str] = None, + api_key: Optional[str] = None, + api_secret: Optional[str] = None, + model_version: Optional[str] = None, + model_alias: Optional[str] = "wenxin_proxyllm", + context_length: Optional[int] = 8192, + executor: Optional[Executor] = None, + ): + if not model: + model = _DEFAULT_MODEL + if not api_key: + api_key = os.getenv("WEN_XIN_API_KEY") + if not api_secret: + api_secret = os.getenv("WEN_XIN_API_SECRET") + if not model_version: + if model: + model_version = MODEL_VERSION_MAPPING.get(model) else: - json_data = line.split(b": ", 1)[1] - decoded_line = json_data.decode("utf-8") - if decoded_line.lower() != "[DONE]".lower(): - obj = json.loads(json_data) - if obj["result"] is not None: - content = obj["result"] - text += content - yield text + model_version = os.getenv("WEN_XIN_MODEL_VERSION") + if not api_key: + raise ValueError("api_key can't be empty") + if not api_secret: + raise ValueError("api_secret can't be empty") + if not model_version: + raise ValueError("model_version can't be empty") + self._model = model + self._api_key = api_key + self._api_secret = api_secret + self._model_version = model_version + self.default_model = self._model + + super().__init__( + model_names=[model, model_alias], + context_length=context_length, + executor=executor, + ) + + @classmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> "WenxinLLMClient": + return cls( + model=model_params.proxyllm_backend, + api_key=model_params.proxy_api_key, + api_secret=model_params.proxy_api_secret, + model_version=model_params.proxy_api_version, + model_alias=model_params.model_name, + context_length=model_params.max_context_size, + executor=default_executor, + ) + + def sync_generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> Iterator[ModelOutput]: + request = self.local_covert_message(request, message_converter) + + try: + access_token = _build_access_token(self._api_key, self._api_secret) + + headers = {"Content-Type": "application/json", "Accept": "application/json"} + + proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{self._model_version}?access_token={access_token}" + + if not access_token: + raise RuntimeError( + 
"Failed to get access token. please set the correct api_key and secret key." + ) + + history, system_message = _to_wenxin_messages(request.get_messages()) + payload = { + "messages": history, + "system": system_message, + "temperature": request.temperature, + "stream": True, + } + + text = "" + res = requests.post( + proxy_server_url, headers=headers, json=payload, stream=True + ) + logger.info( + f"Send request to {proxy_server_url} with real model {self._model}, model version {self._model_version}" + ) + for line in res.iter_lines(): + if line: + if not line.startswith(b"data: "): + error_message = line.decode("utf-8") + yield ModelOutput(text=error_message, error_code=1) + else: + json_data = line.split(b": ", 1)[1] + decoded_line = json_data.decode("utf-8") + if decoded_line.lower() != "[DONE]".lower(): + obj = json.loads(json_data) + if obj["result"] is not None: + content = obj["result"] + text += content + yield ModelOutput(text=text, error_code=0) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py index 8108ad5e0..80974ae16 100644 --- a/dbgpt/model/proxy/llms/zhipu.py +++ b/dbgpt/model/proxy/llms/zhipu.py @@ -1,46 +1,14 @@ -from typing import List +from concurrent.futures import Executor +from typing import Iterator, Optional -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.llms.proxy_model import ProxyModel CHATGLM_DEFAULT_MODEL = "chatglm_pro" -def __convert_2_zhipu_messages(messages: List[ModelMessage]): - chat_round = 0 - wenxin_messages = [] - - last_usr_message = "" - system_messages = [] - - for message in messages: - if message.role == ModelMessageRoleType.HUMAN: - last_usr_message = message.content - elif message.role == ModelMessageRoleType.SYSTEM: - system_messages.append(message.content) - elif message.role == ModelMessageRoleType.AI: - last_ai_message = message.content - wenxin_messages.append({"role": "user", "content": last_usr_message}) - wenxin_messages.append({"role": "assistant", "content": last_ai_message}) - - # build last user messge - - if len(system_messages) > 0: - if len(system_messages) > 1: - end_message = system_messages[-1] - else: - last_message = messages[-1] - if last_message.role == ModelMessageRoleType.HUMAN: - end_message = system_messages[-1] + "\n" + last_message.content - else: - end_message = system_messages[-1] - else: - last_message = messages[-1] - end_message = last_message.content - wenxin_messages.append({"role": "user", "content": end_message}) - return wenxin_messages, system_messages - - def zhipu_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): @@ -48,27 +16,93 @@ def zhipu_generate_stream( model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") - # TODO proxy model use unified config? 
- proxy_api_key = model_params.proxy_api_key - proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend - - import zhipuai - - zhipuai.api_key = proxy_api_key - - messages: List[ModelMessage] = params["messages"] - # TODO: Support convert_to_compatible_format config, zhipu not support system message - convert_to_compatible_format = params.get("convert_to_compatible_format", False) - - history, systems = __convert_2_zhipu_messages(messages) - res = zhipuai.model_api.sse_invoke( - model=proxyllm_backend, - prompt=history, + # convert_to_compatible_format = params.get("convert_to_compatible_format", False) + # history, systems = __convert_2_zhipu_messages(messages) + client: ZhipuLLMClient = model.proxy_llm_client + context = ModelRequestContext(stream=True, user_name=params.get("user_name")) + request = ModelRequest.build_request( + client.default_model, + messages=params["messages"], temperature=params.get("temperature"), - top_p=params.get("top_p"), - incremental=False, + context=context, + max_new_tokens=params.get("max_new_tokens"), ) - for r in res.events(): - if r.event == "add": - yield r.data + for r in client.sync_generate_stream(request): + yield r + + +class ZhipuLLMClient(ProxyLLMClient): + def __init__( + self, + model: Optional[str] = None, + api_key: Optional[str] = None, + model_alias: Optional[str] = "zhipu_proxyllm", + context_length: Optional[int] = 8192, + executor: Optional[Executor] = None, + ): + try: + import zhipuai + + except ImportError as exc: + raise ValueError( + "Could not import python package: zhipuai " + "Please install dashscope by command `pip install zhipuai" + ) from exc + if not model: + model = CHATGLM_DEFAULT_MODEL + if api_key: + zhipuai.api_key = api_key + self._model = model + self.default_model = self._model + + super().__init__( + model_names=[model, model_alias], + context_length=context_length, + executor=executor, + ) + + @classmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> "ZhipuLLMClient": + return cls( + model=model_params.proxyllm_backend, + api_key=model_params.proxy_api_key, + model_alias=model_params.model_name, + context_length=model_params.max_context_size, + executor=default_executor, + ) + + def sync_generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> Iterator[ModelOutput]: + import zhipuai + + request = self.local_covert_message(request, message_converter) + + messages = request.to_common_messages(support_system_role=False) + + model = request.model or self._model + try: + res = zhipuai.model_api.sse_invoke( + model=model, + prompt=messages, + temperature=request.temperature, + # top_p=params.get("top_p"), + incremental=False, + ) + for r in res.events(): + if r.event == "add": + yield ModelOutput(text=r.data, error_code=0) + elif r.event == "error": + yield ModelOutput(text=r.data, error_code=1) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index f4d022a3a..489a5f3ce 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -3,34 +3,21 @@ import importlib.metadata as metadata import logging import os -from abc import ABC from dataclasses import dataclass from typing import ( TYPE_CHECKING, - Any, AsyncIterator, Awaitable, Callable, - Dict, - List, Optional, + Tuple, Union, 
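# Standalone sketch, not part of this patch, of the "data: " line protocol parsed
# by the wenxin streaming loop above: each data line carries a JSON object whose
# "result" field is an incremental text delta. The sample lines are made up.
import json
from typing import Iterator, List


def parse_wenxin_stream(lines: List[bytes]) -> Iterator[str]:
    text = ""
    for line in lines:
        if not line:
            continue
        if not line.startswith(b"data: "):
            # Non-data lines are surfaced as error messages, mirroring the code above.
            yield line.decode("utf-8")
            continue
        json_data = line.split(b": ", 1)[1]
        if json_data.decode("utf-8").lower() == "[done]":
            break
        obj = json.loads(json_data)
        if obj.get("result") is not None:
            text += obj["result"]
            yield text


sample = [b'data: {"result": "Hello"}', b'data: {"result": " world"}', b"data: [DONE]"]
for chunk in parse_wenxin_stream(sample):
    print(chunk)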
) from dbgpt._private.pydantic import model_to_json -from dbgpt.component import ComponentType -from dbgpt.core.awel import BaseOperator, TransformStreamAbsOperator -from dbgpt.core.interface.llm import ( - LLMClient, - MessageConverter, - ModelMetadata, - ModelOutput, - ModelRequest, -) +from dbgpt.core.awel import TransformStreamAbsOperator +from dbgpt.core.interface.llm import ModelOutput from dbgpt.core.operator import BaseLLM -from dbgpt.model.cluster import WorkerManagerFactory -from dbgpt.model.cluster.client import DefaultLLMClient -from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper if TYPE_CHECKING: import httpx @@ -101,14 +88,14 @@ def _initialize_openai_v1(init_params: OpenAIParameters): return openai_params, api_type, api_version -def _build_openai_client(init_params: OpenAIParameters): +def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]: import httpx openai_params, api_type, api_version = _initialize_openai_v1(init_params) if api_type == "azure": from openai import AsyncAzureOpenAI - return AsyncAzureOpenAI( + return api_type, AsyncAzureOpenAI( api_key=openai_params["api_key"], api_version=api_version, azure_endpoint=openai_params["base_url"], @@ -117,149 +104,11 @@ def _build_openai_client(init_params: OpenAIParameters): else: from openai import AsyncOpenAI - return AsyncOpenAI( + return api_type, AsyncOpenAI( **openai_params, http_client=httpx.AsyncClient(proxies=init_params.proxies) ) -class OpenAILLMClient(LLMClient): - """An implementation of LLMClient using OpenAI API. - - In order to have as few dependencies as possible, we directly use the http API. - """ - - def __init__( - self, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_type: Optional[str] = None, - api_version: Optional[str] = None, - model: Optional[str] = "gpt-3.5-turbo", - proxies: Optional["ProxiesTypes"] = None, - timeout: Optional[int] = 240, - model_alias: Optional[str] = "chatgpt_proxyllm", - context_length: Optional[int] = 8192, - openai_client: Optional["ClientType"] = None, - openai_kwargs: Optional[Dict[str, Any]] = None, - ): - self._init_params = OpenAIParameters( - api_type=api_type, - api_base=api_base, - api_key=api_key, - api_version=api_version, - proxies=proxies, - ) - - self._model = model - self._proxies = proxies - self._timeout = timeout - self._model_alias = model_alias - self._context_length = context_length - self._client = openai_client - self._openai_kwargs = openai_kwargs or {} - self._tokenizer = ProxyTokenizerWrapper() - - @property - def client(self) -> ClientType: - if self._client is None: - self._client = _build_openai_client(init_params=self._init_params) - return self._client - - def _build_request( - self, request: ModelRequest, stream: Optional[bool] = False - ) -> Dict[str, Any]: - payload = {"model": request.model or self._model, "stream": stream} - - # Apply openai kwargs - for k, v in self._openai_kwargs.items(): - payload[k] = v - if request.temperature: - payload["temperature"] = request.temperature - if request.max_new_tokens: - payload["max_tokens"] = request.max_new_tokens - return payload - - async def generate( - self, - request: ModelRequest, - message_converter: Optional[MessageConverter] = None, - ) -> ModelOutput: - request = await self.covert_message(request, message_converter) - - messages = request.to_openai_messages() - payload = self._build_request(request) - logger.info( - f"Send request to openai, payload: {payload}\n\n messages:\n{messages}" - ) - try: - chat_completion = await 
self.client.chat.completions.create( - messages=messages, **payload - ) - text = chat_completion.choices[0].message.content - usage = chat_completion.usage.dict() - return ModelOutput(text=text, error_code=0, usage=usage) - except Exception as e: - return ModelOutput( - text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=1, - ) - - async def generate_stream( - self, - request: ModelRequest, - message_converter: Optional[MessageConverter] = None, - ) -> AsyncIterator[ModelOutput]: - request = await self.covert_message(request, message_converter) - messages = request.to_openai_messages() - payload = self._build_request(request, True) - logger.info( - f"Send request to openai, payload: {payload}\n\n messages:\n{messages}" - ) - try: - chat_completion = await self.client.chat.completions.create( - messages=messages, **payload - ) - text = "" - async for r in chat_completion: - if len(r.choices) == 0: - continue - if r.choices[0].delta.content is not None: - content = r.choices[0].delta.content - text += content - yield ModelOutput(text=text, error_code=0) - except Exception as e: - yield ModelOutput( - text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=1, - ) - - async def models(self) -> List[ModelMetadata]: - model_metadata = ModelMetadata( - model=self._model_alias, - context_length=await self.get_context_length(), - ) - return [model_metadata] - - async def get_context_length(self) -> int: - """Get the context length of the model. - - Returns: - int: The context length. - # TODO: This is a temporary solution. We should have a better way to get the context length. - eg. get real context length from the openai api. - """ - return self._context_length - - async def count_token(self, model: str, prompt: str) -> int: - """Count the number of tokens in a given prompt. - - Args: - model (str): The model name. - prompt (str): The prompt. - """ - return self._tokenizer.count_token(prompt, model) - - class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]): """Transform ModelOutput to openai stream format.""" diff --git a/dbgpt/model/operator/model_operator.py b/dbgpt/storage/cache/operator.py similarity index 68% rename from dbgpt/model/operator/model_operator.py rename to dbgpt/storage/cache/operator.py index 061d6fdf0..40260cb54 100644 --- a/dbgpt/model/operator/model_operator.py +++ b/dbgpt/storage/cache/operator.py @@ -1,17 +1,15 @@ import logging from typing import AsyncIterator, Dict, List, Union -from dbgpt.component import ComponentType -from dbgpt.core import ModelOutput +from dbgpt.core import ModelOutput, ModelRequest from dbgpt.core.awel import ( + BaseOperator, BranchFunc, BranchOperator, MapOperator, StreamifyAbsOperator, TransformStreamAbsOperator, ) -from dbgpt.core.awel.operator.base import BaseOperator -from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue logger = logging.getLogger(__name__) @@ -20,70 +18,7 @@ _LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache" -class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]): - """Operator for streaming processing of model outputs. - - Args: - worker_manager (WorkerManager): The manager that handles worker processes for model inference. - **kwargs: Additional keyword arguments. - - Methods: - streamify: Asynchronously processes a stream of inputs, yielding model outputs. 
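# Migration note in code form, not part of this patch: the OpenAILLMClient
# removed from this module now lives under dbgpt.model.proxy (see
# dbgpt/model/proxy/__init__.py and dbgpt/model/proxy/llms/chatgpt.py above).
# old: from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient  # new import path after this refactor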
- """ - - def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None: - super().__init__(**kwargs) - self.worker_manager = worker_manager - - async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]: - """Process inputs as a stream and yield model outputs. - - Args: - input_value (Dict): The input value for the model. - - Returns: - AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs. - """ - if not self.worker_manager: - self.worker_manager = self.system_app.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - async for out in self.worker_manager.generate_stream(input_value): - yield out - - -class ModelOperator(MapOperator[Dict, ModelOutput]): - """Operator for map-based processing of model outputs. - - Args: - worker_manager (WorkerManager): Manager for handling worker processes. - **kwargs: Additional keyword arguments. - - Methods: - map: Asynchronously processes a single input and returns the model output. - """ - - def __init__(self, worker_manager: WorkerManager = None, **kwargs) -> None: - super().__init__(**kwargs) - self.worker_manager = worker_manager - - async def map(self, input_value: Dict) -> ModelOutput: - """Process a single input and return the model output. - - Args: - input_value (Dict): The input value for the model. - - Returns: - ModelOutput: The output from the model. - """ - if not self.worker_manager: - self.worker_manager = self.system_app.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - return await self.worker_manager.generate(input_value) - - -class CachedModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]): +class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput]): """Operator for streaming processing of model outputs with caching. Args: @@ -99,11 +34,11 @@ def __init__(self, cache_manager: CacheManager, **kwargs) -> None: self._cache_manager = cache_manager self._client = LLMCacheClient(cache_manager) - async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]: + async def streamify(self, input_value: ModelRequest) -> AsyncIterator[ModelOutput]: """Process inputs as a stream with cache support and yield model outputs. Args: - input_value (Dict): The input value for the model. + input_value (ModelRequest): The input value for the model. Returns: AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs. @@ -116,7 +51,7 @@ async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]: yield out -class CachedModelOperator(MapOperator[Dict, ModelOutput]): +class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]): """Operator for map-based processing of model outputs with caching. Args: @@ -132,11 +67,11 @@ def __init__(self, cache_manager: CacheManager, **kwargs) -> None: self._cache_manager = cache_manager self._client = LLMCacheClient(cache_manager) - async def map(self, input_value: Dict) -> ModelOutput: + async def map(self, input_value: ModelRequest) -> ModelOutput: """Process a single input with cache support and return the model output. Args: - input_value (Dict): The input value for the model. + input_value (ModelRequest): The input value for the model. Returns: ModelOutput: The output from the model. 
@@ -148,7 +83,7 @@ async def map(self, input_value: Dict) -> ModelOutput: return llm_cache_value.get_value().output -class ModelCacheBranchOperator(BranchOperator[Dict, Dict]): +class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]): """ A branch operator that decides whether to use cached data or to process data using the model. @@ -172,16 +107,18 @@ def __init__( self._model_task_name = model_task_name self._cache_task_name = cache_task_name - async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]: + async def branches( + self, + ) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]: """Defines branch logic based on cache availability. Returns: Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names. """ - async def check_cache_true(input_value: Dict) -> bool: + async def check_cache_true(input_value: ModelRequest) -> bool: # Check if the cache contains the result for the given input - if not input_value["model_cache_enable"]: + if input_value.context and not input_value.context.cache_enable: return False cache_dict = _parse_cache_key_dict(input_value) cache_key: LLMCacheKey = self._client.new_key(**cache_dict) @@ -190,11 +127,11 @@ async def check_cache_true(input_value: Dict) -> bool: f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}" ) await self.current_dag_context.save_to_share_data( - _LLM_MODEL_INPUT_VALUE_KEY, cache_key + _LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True ) return True if cache_value else False - async def check_cache_false(input_value: Dict): + async def check_cache_false(input_value: ModelRequest): # Inverse of check_cache_true return not await check_cache_true(input_value) @@ -275,7 +212,7 @@ async def map(self, input_value: ModelOutput) -> ModelOutput: return input_value -def _parse_cache_key_dict(input_value: Dict) -> Dict: +def _parse_cache_key_dict(input_value: ModelRequest) -> Dict: """Parses and extracts relevant fields from input to form a cache key dictionary. Args: @@ -284,17 +221,15 @@ def _parse_cache_key_dict(input_value: Dict) -> Dict: Returns: Dict: A dictionary used for generating cache keys. """ - prompt: str = input_value.get("prompt") - if prompt: - prompt = prompt.strip() + prompt: str = input_value.messages_to_string().strip() return { "prompt": prompt, - "model_name": input_value.get("model"), - "temperature": input_value.get("temperature"), - "max_new_tokens": input_value.get("max_new_tokens"), - "top_p": input_value.get("top_p", "1.0"), + "model_name": input_value.model, + "temperature": input_value.temperature, + "max_new_tokens": input_value.max_new_tokens, + # "top_p": input_value.get("top_p", "1.0"), # TODO pass model_type - "model_type": input_value.get("model_type", "huggingface"), + # "model_type": input_value.get("model_type", "huggingface"), } diff --git a/examples/awel/simple_chat_dag_example.py b/examples/awel/simple_chat_dag_example.py index 62d30756d..439a3b574 100644 --- a/examples/awel/simple_chat_dag_example.py +++ b/examples/awel/simple_chat_dag_example.py @@ -6,19 +6,18 @@ .. 
code-block:: shell - DBGPT_SERVER="http://127.0.0.1:5000" + DBGPT_SERVER="http://127.0.0.1:5555" + MODEL="gpt-3.5-turbo" curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_chat \ -H "Content-Type: application/json" -d '{ - "model": "proxyllm", + "model": "'"$MODEL"'", "user_input": "hello" }' """ -from typing import Dict - from dbgpt._private.pydantic import BaseModel, Field -from dbgpt.core import ModelMessage +from dbgpt.core import ModelMessage, ModelRequest from dbgpt.core.awel import DAG, HttpTrigger, MapOperator -from dbgpt.model.operator.model_operator import ModelOperator +from dbgpt.model.operator import LLMOperator class TriggerReqBody(BaseModel): @@ -26,22 +25,14 @@ class TriggerReqBody(BaseModel): user_input: str = Field(..., description="User input") -class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]): +class RequestHandleOperator(MapOperator[TriggerReqBody, ModelRequest]): def __init__(self, **kwargs): super().__init__(**kwargs) - async def map(self, input_value: TriggerReqBody) -> Dict: - hist = [] - hist.append(ModelMessage.build_human_message(input_value.user_input)) - hist = list(h.dict() for h in hist) - params = { - "prompt": input_value.user_input, - "messages": hist, - "model": input_value.model, - "echo": False, - } + async def map(self, input_value: TriggerReqBody) -> ModelRequest: + messages = [ModelMessage.build_human_message(input_value.user_input)] print(f"Receive input value: {input_value}") - return params + return ModelRequest.build_request(input_value.model, messages) with DAG("dbgpt_awel_simple_dag_example") as dag: @@ -50,10 +41,9 @@ async def map(self, input_value: TriggerReqBody) -> Dict: "/examples/simple_chat", methods="POST", request_body=TriggerReqBody ) request_handle_task = RequestHandleOperator() - model_task = ModelOperator() - # type(out) == ModelOutput + llm_task = LLMOperator(task_name="llm_task") model_parse_task = MapOperator(lambda out: out.to_dict()) - trigger >> request_handle_task >> model_task >> model_parse_task + trigger >> request_handle_task >> llm_task >> model_parse_task if __name__ == "__main__": diff --git a/examples/awel/simple_chat_history_example.py b/examples/awel/simple_chat_history_example.py index 138cf73d9..6b974ac1a 100644 --- a/examples/awel/simple_chat_history_example.py +++ b/examples/awel/simple_chat_history_example.py @@ -7,7 +7,7 @@ Call with non-streaming response. .. code-block:: shell - DBGPT_SERVER="http://127.0.0.1:5000" + DBGPT_SERVER="http://127.0.0.1:5555" MODEL="gpt-3.5-turbo" # Fist round curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ diff --git a/examples/awel/simple_dag_example.py b/examples/awel/simple_dag_example.py index 64a28ca08..920c7ecdf 100644 --- a/examples/awel/simple_dag_example.py +++ b/examples/awel/simple_dag_example.py @@ -6,7 +6,8 @@ .. 
code-block:: shell - curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\?name\=zhangsan + DBGPT_SERVER="http://127.0.0.1:5555" + curl -X GET $DBGPT_SERVER/api/v1/awel/trigger/examples/hello\?name\=zhangsan """ from dbgpt._private.pydantic import BaseModel, Field From be719b0d14787e05e4c7a096e5f909e8cfec753c Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 13 Jan 2024 18:38:17 +0800 Subject: [PATCH 2/5] fix: Fix LLM echo bug --- dbgpt/app/scene/operator/app_operator.py | 1 + dbgpt/core/interface/llm.py | 4 +++- dbgpt/core/interface/output_parser.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dbgpt/app/scene/operator/app_operator.py b/dbgpt/app/scene/operator/app_operator.py index ef5062b9d..c9e7f03b6 100644 --- a/dbgpt/app/scene/operator/app_operator.py +++ b/dbgpt/app/scene/operator/app_operator.py @@ -97,6 +97,7 @@ async def map(self, input_value: ChatComposerInput) -> ModelRequest: temperature=self._temperature, max_new_tokens=self._max_new_tokens, span_id=span_id, + echo=self._echo, ) return model_request diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index 21050527e..7e903c9ee 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -176,7 +176,7 @@ class ModelRequest: """The stop token ids of the model inference.""" context_len: Optional[int] = None """The context length of the model inference.""" - echo: Optional[bool] = True + echo: Optional[bool] = False """Whether to echo the input messages.""" span_id: Optional[str] = None """The span id of the model inference.""" @@ -246,6 +246,7 @@ def build_request( messages: List[ModelMessage], context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None, stream: Optional[bool] = False, + echo: Optional[bool] = False, **kwargs, ): if not context: @@ -262,6 +263,7 @@ def build_request( model=model, messages=messages, context=context, + echo=echo, **kwargs, ) diff --git a/dbgpt/core/interface/output_parser.py b/dbgpt/core/interface/output_parser.py index 1fdac4510..b06c45dc6 100644 --- a/dbgpt/core/interface/output_parser.py +++ b/dbgpt/core/interface/output_parser.py @@ -44,7 +44,7 @@ def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len): """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. 
""" model_context = data.get("model_context") - has_echo = True + has_echo = False if model_context and "prompt_echo_len_char" in model_context: prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1)) has_echo = bool(model_context.get("echo", False)) From a4a0505d0a29d74f27692d922e252f6099c2abd5 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 13 Jan 2024 21:09:39 +0800 Subject: [PATCH 3/5] fix: Fix remote model serving error --- dbgpt/core/awel/runner/job_manager.py | 4 ++-- dbgpt/core/awel/runner/local_runner.py | 5 ++--- dbgpt/core/interface/llm.py | 2 +- dbgpt/model/adapter/base.py | 1 + dbgpt/model/cluster/base.py | 4 +++- dbgpt/model/cluster/worker/default_worker.py | 2 ++ 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dbgpt/core/awel/runner/job_manager.py b/dbgpt/core/awel/runner/job_manager.py index b015f85e7..b8b7d2ebd 100644 --- a/dbgpt/core/awel/runner/job_manager.py +++ b/dbgpt/core/awel/runner/job_manager.py @@ -76,12 +76,12 @@ def _save_call_data( return id2call_data if len(root_nodes) == 1: node = root_nodes[0] - logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}") + logger.debug(f"Save call data to node {node.node_id}, call_data: {call_data}") id2call_data[node.node_id] = call_data else: for node in root_nodes: node_id = node.node_id - logger.info( + logger.debug( f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}" ) id2call_data[node_id] = call_data.get(node_id) diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index 6fb69619b..8eb16f574 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -35,9 +35,8 @@ async def execute_workflow( node_to_outputs=node_outputs, node_name_to_ids=job_manager._node_name_to_ids, ) - logger.info( - f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}" - ) + logger.info(f"Begin run workflow from end operator, id: {node.node_id}") + logger.debug(f"Node id {node.node_id}, call_data: {call_data}") skip_node_ids = set() system_app: SystemApp = DAGVar.get_current_system_app() diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index 7e903c9ee..50a668a61 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -208,7 +208,7 @@ def to_dict(self) -> Dict[str, Any]: map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages) ) # Skip None fields - return {k: v for k, v in asdict(new_reqeust).items() if v} + return {k: v for k, v in asdict(new_reqeust).items() if v is not None} def to_trace_metadata(self): metadata = self.to_dict() diff --git a/dbgpt/model/adapter/base.py b/dbgpt/model/adapter/base.py index 06089b1a4..ded5c793b 100644 --- a/dbgpt/model/adapter/base.py +++ b/dbgpt/model/adapter/base.py @@ -440,6 +440,7 @@ def model_adaptation( for m in messages ] params["messages"] = messages + params["string_prompt"] = ModelMessage.messages_to_string(messages) if not self.apply_conv_template(): # No need to apply conversation template, now for proxy LLM diff --git a/dbgpt/model/cluster/base.py b/dbgpt/model/cluster/base.py index f60fb0e54..8d1832b56 100644 --- a/dbgpt/model/cluster/base.py +++ b/dbgpt/model/cluster/base.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Any, Dict, List from dbgpt._private.pydantic import BaseModel from dbgpt.core.interface.message import ModelMessage @@ -24,6 +24,8 @@ class PromptRequest(BaseModel): """Whether to return metrics of 
inference""" version: str = "v2" """Message version, default to v2""" + context: Dict[str, Any] = None + """Context information for the model""" class EmbeddingsRequest(BaseModel): diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index a28c1ed9a..f07e13d36 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -295,6 +295,8 @@ def _prepare_generate_stream(self, params: Dict, span_operation_name: str): self.model, self.model_path ) str_prompt = params.get("prompt") + if not str_prompt: + str_prompt = params.get("string_prompt") print( f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n" ) From 80f6a5116eb1f0ed7e4bea7cc58b16ad5496635c Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 13 Jan 2024 22:22:38 +0800 Subject: [PATCH 4/5] chore: Add pylint for rag --- Makefile | 3 + dbgpt/app/scene/chat_knowledge/v1/chat.py | 4 - dbgpt/rag/chunk.py | 2 +- dbgpt/rag/chunk_manager.py | 2 +- dbgpt/rag/embedding/embedding_factory.py | 3 +- dbgpt/rag/embedding/embeddings.py | 22 ++--- dbgpt/rag/extractor/base.py | 2 +- dbgpt/rag/extractor/summary.py | 2 +- dbgpt/rag/graph/graph_engine.py | 5 +- dbgpt/rag/graph/graph_factory.py | 1 + dbgpt/rag/graph/graph_search.py | 7 +- dbgpt/rag/graph/index_struct.py | 3 +- dbgpt/rag/graph/kv_index.py | 1 + dbgpt/rag/graph/node.py | 2 +- dbgpt/rag/knowledge/base.py | 10 +- dbgpt/rag/knowledge/csv.py | 7 +- dbgpt/rag/knowledge/docx.py | 9 +- dbgpt/rag/knowledge/factory.py | 62 ++++++++---- dbgpt/rag/knowledge/html.py | 6 +- dbgpt/rag/knowledge/markdown.py | 6 +- dbgpt/rag/knowledge/pdf.py | 6 +- dbgpt/rag/knowledge/pptx.py | 6 +- dbgpt/rag/knowledge/string.py | 4 +- dbgpt/rag/knowledge/tests/test_csv.py | 3 +- dbgpt/rag/knowledge/tests/test_docx.py | 3 +- dbgpt/rag/knowledge/tests/test_html.py | 3 +- dbgpt/rag/knowledge/tests/test_markdown.py | 3 +- dbgpt/rag/knowledge/tests/test_pdf.py | 3 +- dbgpt/rag/knowledge/tests/test_txt.py | 3 +- dbgpt/rag/knowledge/txt.py | 6 +- dbgpt/rag/knowledge/url.py | 4 +- dbgpt/rag/operator/datasource.py | 1 + dbgpt/rag/operator/knowledge.py | 2 +- dbgpt/rag/operator/rerank.py | 2 +- dbgpt/rag/operator/rewrite.py | 2 +- dbgpt/rag/retriever/base.py | 1 + dbgpt/rag/retriever/db_schema.py | 99 +++++++++++--------- dbgpt/rag/retriever/embedding.py | 57 ++++++----- dbgpt/rag/retriever/rewrite.py | 3 +- dbgpt/rag/retriever/tests/test_db_struct.py | 5 +- dbgpt/rag/retriever/tests/test_embedding.py | 1 + dbgpt/rag/summary/db_summary_client.py | 13 +-- dbgpt/rag/summary/rdbms_db_summary.py | 3 +- dbgpt/rag/text_splitter/pre_text_splitter.py | 2 +- dbgpt/rag/text_splitter/text_splitter.py | 4 +- dbgpt/rag/text_splitter/token_splitter.py | 2 +- dbgpt/serve/rag/assembler/base.py | 2 +- dbgpt/serve/rag/assembler/db_schema.py | 6 +- dbgpt/serve/rag/assembler/embedding.py | 2 +- dbgpt/serve/rag/assembler/summary.py | 2 +- 50 files changed, 234 insertions(+), 178 deletions(-) diff --git a/Makefile b/Makefile index c1029f3e9..99edfdd41 100644 --- a/Makefile +++ b/Makefile @@ -44,10 +44,12 @@ fmt: setup ## Format Python code $(VENV_BIN)/isort dbgpt/core/ $(VENV_BIN)/isort dbgpt/datasource/ $(VENV_BIN)/isort dbgpt/model/ + $(VENV_BIN)/isort dbgpt/rag/ # TODO: $(VENV_BIN)/isort dbgpt/serve $(VENV_BIN)/isort dbgpt/serve/core/ $(VENV_BIN)/isort dbgpt/serve/agent/ $(VENV_BIN)/isort dbgpt/serve/conversation/ + $(VENV_BIN)/isort dbgpt/serve/rag/ $(VENV_BIN)/isort 
dbgpt/serve/utils/_template_files $(VENV_BIN)/isort dbgpt/storage/ $(VENV_BIN)/isort dbgpt/train/ @@ -68,6 +70,7 @@ fmt: setup ## Format Python code $(VENV_BIN)/blackdoc dbgpt/core/ $(VENV_BIN)/blackdoc dbgpt/datasource/ $(VENV_BIN)/blackdoc dbgpt/model/ + $(VENV_BIN)/blackdoc dbgpt/rag/ $(VENV_BIN)/blackdoc dbgpt/serve/ # TODO: $(VENV_BIN)/blackdoc dbgpt/storage/ $(VENV_BIN)/blackdoc dbgpt/train/ diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py index 4090d9238..6f60e46e0 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/chat.py +++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py @@ -80,10 +80,6 @@ def __init__(self, chat_param: Dict): vector_store_config=config, ) query_rewrite = None - self.worker_manager = CFG.SYSTEM_APP.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager) if CFG.KNOWLEDGE_SEARCH_REWRITE: query_rewrite = QueryRewrite( llm_client=self.llm_client, diff --git a/dbgpt/rag/chunk.py b/dbgpt/rag/chunk.py index ddfb8cd90..75a061f4f 100644 --- a/dbgpt/rag/chunk.py +++ b/dbgpt/rag/chunk.py @@ -2,7 +2,7 @@ import uuid from typing import Any, Dict -from pydantic import Field, BaseModel +from pydantic import BaseModel, Field class Document(BaseModel): diff --git a/dbgpt/rag/chunk_manager.py b/dbgpt/rag/chunk_manager.py index 128ddca5e..d462a958d 100644 --- a/dbgpt/rag/chunk_manager.py +++ b/dbgpt/rag/chunk_manager.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional, List, Any +from typing import Any, List, Optional from pydantic import BaseModel, Field diff --git a/dbgpt/rag/embedding/embedding_factory.py b/dbgpt/rag/embedding/embedding_factory.py index d63ad968f..96864a1ff 100644 --- a/dbgpt/rag/embedding/embedding_factory.py +++ b/dbgpt/rag/embedding/embedding_factory.py @@ -1,6 +1,7 @@ from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, Type, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Type from dbgpt.component import BaseComponent from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index cdc338508..76524a2a7 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import requests -from pydantic import Field, Extra, BaseModel +from pydantic import BaseModel, Extra, Field DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" @@ -54,12 +54,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): from .embeddings import HuggingFaceEmbeddings model_name = "sentence-transformers/all-mpnet-base-v2" - model_kwargs = {'device': 'cpu'} - encode_kwargs = {'normalize_embeddings': False} + model_kwargs = {"device": "cpu"} + encode_kwargs = {"normalize_embeddings": False} hf = HuggingFaceEmbeddings( model_name=model_name, model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs + encode_kwargs=encode_kwargs, ) """ @@ -142,12 +142,12 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): from langchain.embeddings import HuggingFaceInstructEmbeddings model_name = "hkunlp/instructor-large" - model_kwargs = {'device': 'cpu'} - encode_kwargs = {'normalize_embeddings': True} + model_kwargs = {"device": "cpu"} + encode_kwargs = {"normalize_embeddings": True} hf = HuggingFaceInstructEmbeddings( model_name=model_name, 
model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs + encode_kwargs=encode_kwargs, ) """ @@ -221,12 +221,12 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): from langchain.embeddings import HuggingFaceBgeEmbeddings model_name = "BAAI/bge-large-en" - model_kwargs = {'device': 'cpu'} - encode_kwargs = {'normalize_embeddings': True} + model_kwargs = {"device": "cpu"} + encode_kwargs = {"normalize_embeddings": True} hf = HuggingFaceBgeEmbeddings( model_name=model_name, model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs + encode_kwargs=encode_kwargs, ) """ @@ -336,7 +336,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: hf_embeddings = HuggingFaceInferenceAPIEmbeddings( api_key="your_api_key", - model_name="sentence-transformers/all-MiniLM-l6-v2" + model_name="sentence-transformers/all-MiniLM-l6-v2", ) texts = ["Hello, world!", "How are you?"] hf_embeddings.embed_documents(texts) diff --git a/dbgpt/rag/extractor/base.py b/dbgpt/rag/extractor/base.py index 0ea072fe4..b9049b176 100644 --- a/dbgpt/rag/extractor/base.py +++ b/dbgpt/rag/extractor/base.py @@ -1,4 +1,4 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from typing import List from dbgpt.core import LLMClient diff --git a/dbgpt/rag/extractor/summary.py b/dbgpt/rag/extractor/summary.py index 97977d61f..c3f84d232 100644 --- a/dbgpt/rag/extractor/summary.py +++ b/dbgpt/rag/extractor/summary.py @@ -1,7 +1,7 @@ from typing import List, Optional from dbgpt._private.llm_metadata import LLMMetadata -from dbgpt.core import LLMClient, ModelRequest, ModelMessageRoleType +from dbgpt.core import LLMClient, ModelMessageRoleType, ModelRequest from dbgpt.rag.chunk import Chunk from dbgpt.rag.extractor.base import Extractor from dbgpt.util import utils diff --git a/dbgpt/rag/graph/graph_engine.py b/dbgpt/rag/graph/graph_engine.py index 50e66c4c9..b00ca6695 100644 --- a/dbgpt/rag/graph/graph_engine.py +++ b/dbgpt/rag/graph/graph_engine.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Optional, Callable, Tuple, List +from typing import Any, Callable, List, Optional, Tuple from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -87,9 +87,10 @@ def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]: def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]: """Extract triplets from text by llm""" + import uuid + from dbgpt.app.scene import ChatScene from dbgpt.util.chat_util import llm_chat_response_nostream - import uuid chat_param = { "chat_session_id": uuid.uuid1(), diff --git a/dbgpt/rag/graph/graph_factory.py b/dbgpt/rag/graph/graph_factory.py index c868190e5..ebdc23dd5 100644 --- a/dbgpt/rag/graph/graph_factory.py +++ b/dbgpt/rag/graph/graph_factory.py @@ -1,4 +1,5 @@ from __future__ import annotations + from abc import ABC, abstractmethod from typing import Any, Type diff --git a/dbgpt/rag/graph/graph_search.py b/dbgpt/rag/graph/graph_search.py index 055e9aa5a..d485afc66 100644 --- a/dbgpt/rag/graph/graph_search.py +++ b/dbgpt/rag/graph/graph_search.py @@ -2,11 +2,11 @@ import os from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional, Dict, Any, Set, Callable +from typing import Any, Callable, Dict, List, Optional, Set from langchain.schema import Document -from dbgpt.rag.graph.node import BaseNode, TextNode, NodeWithScore +from dbgpt.rag.graph.node import BaseNode, NodeWithScore, TextNode from dbgpt.rag.graph.search import 
BaseSearch, SearchMode logger = logging.getLogger(__name__) @@ -77,9 +77,10 @@ async def _extract_subject_entities(self, query_str: str) -> Set[str]: async def _extract_entities_by_llm(self, text: str) -> Set[str]: """extract subject entities from text by llm""" + import uuid + from dbgpt.app.scene import ChatScene from dbgpt.util.chat_util import llm_chat_response_nostream - import uuid chat_param = { "chat_session_id": uuid.uuid1(), diff --git a/dbgpt/rag/graph/index_struct.py b/dbgpt/rag/graph/index_struct.py index e09d68c83..010d1dcea 100644 --- a/dbgpt/rag/graph/index_struct.py +++ b/dbgpt/rag/graph/index_struct.py @@ -11,9 +11,8 @@ from dataclasses_json import DataClassJsonMixin - from dbgpt.rag.graph.index_type import IndexStructType -from dbgpt.rag.graph.node import TextNode, BaseNode +from dbgpt.rag.graph.node import BaseNode, TextNode # TODO: legacy backport of old Node class Node = TextNode diff --git a/dbgpt/rag/graph/kv_index.py b/dbgpt/rag/graph/kv_index.py index 7b44b7d04..963b1da0e 100644 --- a/dbgpt/rag/graph/kv_index.py +++ b/dbgpt/rag/graph/kv_index.py @@ -1,4 +1,5 @@ from typing import List, Optional + from llama_index.data_structs.data_structs import IndexStruct from llama_index.storage.index_store.utils import ( index_struct_to_json, diff --git a/dbgpt/rag/graph/node.py b/dbgpt/rag/graph/node.py index aec68c36b..aef3e4c30 100644 --- a/dbgpt/rag/graph/node.py +++ b/dbgpt/rag/graph/node.py @@ -8,9 +8,9 @@ from typing import Any, Dict, List, Optional, Union from langchain.schema import Document -from dbgpt._private.pydantic import BaseModel, Field, root_validator from typing_extensions import Self +from dbgpt._private.pydantic import BaseModel, Field, root_validator DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}" DEFAULT_METADATA_TMPL = "{key}: {value}" diff --git a/dbgpt/rag/knowledge/base.py b/dbgpt/rag/knowledge/base.py index cdaa7e3ed..af149b164 100644 --- a/dbgpt/rag/knowledge/base.py +++ b/dbgpt/rag/knowledge/base.py @@ -1,14 +1,14 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from enum import Enum -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Document from dbgpt.rag.text_splitter.text_splitter import ( - RecursiveCharacterTextSplitter, - MarkdownHeaderTextSplitter, - ParagraphTextSplitter, CharacterTextSplitter, + MarkdownHeaderTextSplitter, PageTextSplitter, + ParagraphTextSplitter, + RecursiveCharacterTextSplitter, SeparatorTextSplitter, ) diff --git a/dbgpt/rag/knowledge/csv.py b/dbgpt/rag/knowledge/csv.py index 24ee2ad82..ec41cd7f5 100644 --- a/dbgpt/rag/knowledge/csv.py +++ b/dbgpt/rag/knowledge/csv.py @@ -1,11 +1,12 @@ -from typing import Optional, Any, List import csv +from typing import Any, List, Optional + from dbgpt.rag.chunk import Document from dbgpt.rag.knowledge.base import ( - KnowledgeType, - Knowledge, ChunkStrategy, DocumentType, + Knowledge, + KnowledgeType, ) diff --git a/dbgpt/rag/knowledge/docx.py b/dbgpt/rag/knowledge/docx.py index df99a988d..ae1075c77 100644 --- a/dbgpt/rag/knowledge/docx.py +++ b/dbgpt/rag/knowledge/docx.py @@ -1,13 +1,14 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional + +import docx from dbgpt.rag.chunk import Document from dbgpt.rag.knowledge.base import ( - KnowledgeType, - Knowledge, ChunkStrategy, DocumentType, + Knowledge, + KnowledgeType, ) -import docx class DocxKnowledge(Knowledge): diff --git a/dbgpt/rag/knowledge/factory.py b/dbgpt/rag/knowledge/factory.py index 
845acbbe1..622d18b41 100644 --- a/dbgpt/rag/knowledge/factory.py +++ b/dbgpt/rag/knowledge/factory.py @@ -1,7 +1,6 @@ -from typing import Optional -from typing import List +from typing import List, Optional -from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge +from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType from dbgpt.rag.knowledge.string import StringKnowledge from dbgpt.rag.knowledge.url import URLKnowledge @@ -32,11 +31,21 @@ def create( Args: datasource: path of the file to convert knowledge_type: type of knowledge - Example: + + Examples: + .. code-block:: python - >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory - >>> url_knowlege = KnowledgeFactory.create(datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL) - >>> doc_knowlege = KnowledgeFactory.create(datasource="path/to/document.pdf", knowledge_type=KnowledgeType.DOCUMENT) + + from dbgpt.rag.knowledge.factory import KnowledgeFactory + + url_knowlege = KnowledgeFactory.create( + datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL + ) + doc_knowlege = KnowledgeFactory.create( + datasource="path/to/document.pdf", + knowledge_type=KnowledgeType.DOCUMENT, + ) + """ match knowledge_type: case KnowledgeType.DOCUMENT: @@ -57,13 +66,22 @@ def from_file_path( knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, ) -> Knowledge: """Create knowledge from path + Args: param file_path: path of the file to convert param knowledge_type: type of knowledge - Example: + + Examples: + .. code-block:: python - >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory - >>> doc_knowlege = KnowledgeFactory.create(datasource="path/to/document.pdf", knowledge_type=KnowledgeType.DOCUMENT) + + from dbgpt.rag.knowledge.factory import KnowledgeFactory + + doc_knowlege = KnowledgeFactory.create( + datasource="path/to/document.pdf", + knowledge_type=KnowledgeType.DOCUMENT, + ) + """ factory = cls(file_path=file_path, knowledge_type=knowledge_type) return factory._select_document_knowledge( @@ -76,13 +94,21 @@ def from_url( knowledge_type: Optional[KnowledgeType] = KnowledgeType.URL, ) -> Knowledge: """Create knowledge from url + Args: param url: url of the file to convert param knowledge_type: type of knowledge - Example: + + Examples: + .. 
code-block:: python - >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory - >>> url_knowlege = KnowledgeFactory.create(datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL) + + from dbgpt.rag.knowledge.factory import KnowledgeFactory + + url_knowlege = KnowledgeFactory.create( + datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL + ) + """ return URLKnowledge( url=url, @@ -130,14 +156,14 @@ def subclasses(cls): def _get_knowledge_subclasses() -> List[Knowledge]: """get all knowledge subclasses""" from dbgpt.rag.knowledge.base import Knowledge - from dbgpt.rag.knowledge.pdf import PDFKnowledge + from dbgpt.rag.knowledge.csv import CSVKnowledge from dbgpt.rag.knowledge.docx import DocxKnowledge + from dbgpt.rag.knowledge.html import HTMLKnowledge from dbgpt.rag.knowledge.markdown import MarkdownKnowledge - from dbgpt.rag.knowledge.csv import CSVKnowledge - from dbgpt.rag.knowledge.txt import TXTKnowledge + from dbgpt.rag.knowledge.pdf import PDFKnowledge from dbgpt.rag.knowledge.pptx import PPTXKnowledge - from dbgpt.rag.knowledge.html import HTMLKnowledge - from dbgpt.rag.knowledge.url import URLKnowledge from dbgpt.rag.knowledge.string import StringKnowledge + from dbgpt.rag.knowledge.txt import TXTKnowledge + from dbgpt.rag.knowledge.url import URLKnowledge return Knowledge.__subclasses__() diff --git a/dbgpt/rag/knowledge/html.py b/dbgpt/rag/knowledge/html.py index eb3dce643..7fa3e545f 100644 --- a/dbgpt/rag/knowledge/html.py +++ b/dbgpt/rag/knowledge/html.py @@ -1,13 +1,13 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional import chardet from dbgpt.rag.chunk import Document from dbgpt.rag.knowledge.base import ( - Knowledge, - KnowledgeType, ChunkStrategy, DocumentType, + Knowledge, + KnowledgeType, ) diff --git a/dbgpt/rag/knowledge/markdown.py b/dbgpt/rag/knowledge/markdown.py index c707c893b..90270fd0f 100644 --- a/dbgpt/rag/knowledge/markdown.py +++ b/dbgpt/rag/knowledge/markdown.py @@ -1,11 +1,11 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Document from dbgpt.rag.knowledge.base import ( - KnowledgeType, - Knowledge, ChunkStrategy, DocumentType, + Knowledge, + KnowledgeType, ) diff --git a/dbgpt/rag/knowledge/pdf.py b/dbgpt/rag/knowledge/pdf.py index 87be838ca..d6807f5d9 100644 --- a/dbgpt/rag/knowledge/pdf.py +++ b/dbgpt/rag/knowledge/pdf.py @@ -1,11 +1,11 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Document from dbgpt.rag.knowledge.base import ( - Knowledge, - KnowledgeType, ChunkStrategy, DocumentType, + Knowledge, + KnowledgeType, ) diff --git a/dbgpt/rag/knowledge/pptx.py b/dbgpt/rag/knowledge/pptx.py index 431d51444..90fc337bf 100644 --- a/dbgpt/rag/knowledge/pptx.py +++ b/dbgpt/rag/knowledge/pptx.py @@ -1,11 +1,11 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Document from dbgpt.rag.knowledge.base import ( - Knowledge, - KnowledgeType, ChunkStrategy, DocumentType, + Knowledge, + KnowledgeType, ) diff --git a/dbgpt/rag/knowledge/string.py b/dbgpt/rag/knowledge/string.py index 3b8fd7d5a..44dc74541 100644 --- a/dbgpt/rag/knowledge/string.py +++ b/dbgpt/rag/knowledge/string.py @@ -1,7 +1,7 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Document -from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge, ChunkStrategy +from 
dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType class StringKnowledge(Knowledge): diff --git a/dbgpt/rag/knowledge/tests/test_csv.py b/dbgpt/rag/knowledge/tests/test_csv.py index 0f8029354..69e887cf5 100644 --- a/dbgpt/rag/knowledge/tests/test_csv.py +++ b/dbgpt/rag/knowledge/tests/test_csv.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import MagicMock, mock_open, patch +import pytest + from dbgpt.rag.knowledge.csv import CSVKnowledge MOCK_CSV_DATA = "id,name,age\n1,John Doe,30\n2,Jane Smith,25\n3,Bob Johnson,40" diff --git a/dbgpt/rag/knowledge/tests/test_docx.py b/dbgpt/rag/knowledge/tests/test_docx.py index e6f1ddc4f..855f561e0 100644 --- a/dbgpt/rag/knowledge/tests/test_docx.py +++ b/dbgpt/rag/knowledge/tests/test_docx.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import MagicMock, patch +import pytest + from dbgpt.rag.knowledge.docx import DocxKnowledge diff --git a/dbgpt/rag/knowledge/tests/test_html.py b/dbgpt/rag/knowledge/tests/test_html.py index 9cb123c5c..abf3388fe 100644 --- a/dbgpt/rag/knowledge/tests/test_html.py +++ b/dbgpt/rag/knowledge/tests/test_html.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import mock_open, patch +import pytest + from dbgpt.rag.knowledge.html import HTMLKnowledge MOCK_HTML_CONTENT = b""" diff --git a/dbgpt/rag/knowledge/tests/test_markdown.py b/dbgpt/rag/knowledge/tests/test_markdown.py index 619055e10..bcdfc4ee9 100644 --- a/dbgpt/rag/knowledge/tests/test_markdown.py +++ b/dbgpt/rag/knowledge/tests/test_markdown.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import mock_open, patch +import pytest + from dbgpt.rag.knowledge.markdown import MarkdownKnowledge MOCK_MARKDOWN_DATA = """# Header 1 diff --git a/dbgpt/rag/knowledge/tests/test_pdf.py b/dbgpt/rag/knowledge/tests/test_pdf.py index a4b130246..ea3c4be99 100644 --- a/dbgpt/rag/knowledge/tests/test_pdf.py +++ b/dbgpt/rag/knowledge/tests/test_pdf.py @@ -1,5 +1,6 @@ +from unittest.mock import MagicMock, mock_open, patch + import pytest -from unittest.mock import MagicMock, patch, mock_open from dbgpt.rag.knowledge.pdf import PDFKnowledge diff --git a/dbgpt/rag/knowledge/tests/test_txt.py b/dbgpt/rag/knowledge/tests/test_txt.py index ecdb241de..bb28e07e6 100644 --- a/dbgpt/rag/knowledge/tests/test_txt.py +++ b/dbgpt/rag/knowledge/tests/test_txt.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import mock_open, patch +import pytest + from dbgpt.rag.knowledge.txt import TXTKnowledge MOCK_TXT_CONTENT = b"Sample text content for testing.\nAnother line of text." 
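
The knowledge classes touched here (csv, docx, html, markdown, pdf, txt, url, and so on) are all reachable through the factory entry points documented above; a small sketch for the plain-text case, where the file path is a placeholder and suffix-based selection of TXTKnowledge is an assumption about the factory's dispatch:

.. code-block:: python

    from dbgpt.rag.knowledge.base import KnowledgeType
    from dbgpt.rag.knowledge.factory import KnowledgeFactory

    # DOCUMENT is the default knowledge type; the factory is expected to pick
    # the concrete class (here TXTKnowledge) from the ".txt" suffix.
    txt_knowledge = KnowledgeFactory.from_file_path(
        "path/to/notes.txt", knowledge_type=KnowledgeType.DOCUMENT
    )
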
diff --git a/dbgpt/rag/knowledge/txt.py b/dbgpt/rag/knowledge/txt.py index 57b3fcdd7..7be946133 100644 --- a/dbgpt/rag/knowledge/txt.py +++ b/dbgpt/rag/knowledge/txt.py @@ -1,13 +1,13 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional import chardet from dbgpt.rag.chunk import Document from dbgpt.rag.knowledge.base import ( - Knowledge, - KnowledgeType, ChunkStrategy, DocumentType, + Knowledge, + KnowledgeType, ) diff --git a/dbgpt/rag/knowledge/url.py b/dbgpt/rag/knowledge/url.py index 01b4e4a8e..f1f4ab78c 100644 --- a/dbgpt/rag/knowledge/url.py +++ b/dbgpt/rag/knowledge/url.py @@ -1,7 +1,7 @@ -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Document -from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge, ChunkStrategy +from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType class URLKnowledge(Knowledge): diff --git a/dbgpt/rag/operator/datasource.py b/dbgpt/rag/operator/datasource.py index c015eb33b..ea138dc1c 100644 --- a/dbgpt/rag/operator/datasource.py +++ b/dbgpt/rag/operator/datasource.py @@ -1,4 +1,5 @@ from typing import Any + from dbgpt.core.interface.retriever import RetrieverOperator from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary diff --git a/dbgpt/rag/operator/knowledge.py b/dbgpt/rag/operator/knowledge.py index 01869a3a4..02de6a3e2 100644 --- a/dbgpt/rag/operator/knowledge.py +++ b/dbgpt/rag/operator/knowledge.py @@ -2,7 +2,7 @@ from dbgpt.core.awel import MapOperator from dbgpt.core.awel.task.base import IN -from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge +from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType from dbgpt.rag.knowledge.factory import KnowledgeFactory diff --git a/dbgpt/rag/operator/rerank.py b/dbgpt/rag/operator/rerank.py index 1641e4744..bb6485b6e 100644 --- a/dbgpt/rag/operator/rerank.py +++ b/dbgpt/rag/operator/rerank.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List +from typing import Any, List, Optional from dbgpt.core import LLMClient from dbgpt.core.awel import MapOperator diff --git a/dbgpt/rag/operator/rewrite.py b/dbgpt/rag/operator/rewrite.py index 9d63b3540..bade2677a 100644 --- a/dbgpt/rag/operator/rewrite.py +++ b/dbgpt/rag/operator/rewrite.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List +from typing import Any, List, Optional from dbgpt.core import LLMClient from dbgpt.core.awel import MapOperator diff --git a/dbgpt/rag/retriever/base.py b/dbgpt/rag/retriever/base.py index 2af4ab99e..86c8133d1 100644 --- a/dbgpt/rag/retriever/base.py +++ b/dbgpt/rag/retriever/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import List, Tuple + from dbgpt.rag.chunk import Chunk diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index 8e09ea85f..72fe425f0 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,13 +1,13 @@ from functools import reduce from typing import List, Optional -from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary -from dbgpt.util.chat_util import run_async_tasks from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.chunk import Chunk from dbgpt.rag.retriever.base import BaseRetriever -from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker +from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker +from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from 
dbgpt.storage.vector_store.connector import VectorStoreConnector +from dbgpt.util.chat_util import run_async_tasks class DBSchemaRetriever(BaseRetriever): @@ -29,50 +29,59 @@ def __init__( query_rewrite (bool): query rewrite rerank (Ranker): rerank vector_store_connector (VectorStoreConnector): vector store connector - code example: - .. code-block:: python - >>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect - >>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler - >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector - >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig - >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever - - def _create_temporary_connection(): - connect = SQLiteTempConnect.create_temporary_db() - connect.create_temp_tables( - { - "user": { - "columns": { - "id": "INTEGER PRIMARY KEY", - "name": "TEXT", - "age": "INTEGER", - }, - "data": [ - (1, "Tom", 10), - (2, "Jerry", 16), - (3, "Jack", 18), - (4, "Alice", 20), - (5, "Bob", 22), - ], + + Examples: + + .. code-block:: python + + from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect + from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler + from dbgpt.storage.vector_store.connector import VectorStoreConnector + from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig + from dbgpt.rag.retriever.embedding import EmbeddingRetriever + + + def _create_temporary_connection(): + connect = SQLiteTempConnect.create_temporary_db() + connect.create_temp_tables( + { + "user": { + "columns": { + "id": "INTEGER PRIMARY KEY", + "name": "TEXT", + "age": "INTEGER", + }, + "data": [ + (1, "Tom", 10), + (2, "Jerry", 16), + (3, "Jack", 18), + (4, "Alice", 20), + (5, "Bob", 22), + ], + } } - } + ) + return connect + + + connection = _create_temporary_connection() + vector_store_config = ChromaVectorConfig(name="vector_store_name") + embedding_model_path = "{your_embedding_model_path}" + embedding_fn = embedding_factory.create(model_name=embedding_model_path) + vector_connector = VectorStoreConnector.from_default( + "Chroma", + vector_store_config=vector_store_config, + embedding_fn=embedding_fn, ) - return connect - connection = _create_temporary_connection() - vector_store_config = ChromaVectorConfig(name="vector_store_name") - embedding_model_path = "{your_embedding_model_path}" - embedding_fn = embedding_factory.create( - model_name=embedding_model_path - ) - vector_connector = VectorStoreConnector.from_default( - "Chroma", - vector_store_config=vector_store_config, - embedding_fn=embedding_fn - ) - # get db struct retriever - retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector) - chunks = retriever.retrieve("show columns from table") - print(f"db struct rag example results:{[chunk.content for chunk in chunks]}") + # get db struct retriever + retriever = DBSchemaRetriever( + top_k=3, vector_store_connector=vector_connector + ) + chunks = retriever.retrieve("show columns from table") + print( + f"db struct rag example results:{[chunk.content for chunk in chunks]}" + ) + """ self._top_k = top_k diff --git a/dbgpt/rag/retriever/embedding.py b/dbgpt/rag/retriever/embedding.py index d7607e506..a9e24065f 100644 --- a/dbgpt/rag/retriever/embedding.py +++ b/dbgpt/rag/retriever/embedding.py @@ -1,12 +1,12 @@ from functools import reduce from typing import List, Optional -from dbgpt.rag.retriever.rewrite import QueryRewrite -from dbgpt.util.chat_util import run_async_tasks from dbgpt.rag.chunk 
import Chunk from dbgpt.rag.retriever.base import BaseRetriever -from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker +from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker +from dbgpt.rag.retriever.rewrite import QueryRewrite from dbgpt.storage.vector_store.connector import VectorStoreConnector +from dbgpt.util.chat_util import run_async_tasks class EmbeddingRetriever(BaseRetriever): @@ -25,31 +25,38 @@ def __init__( query_rewrite (Optional[QueryRewrite]): query rewrite rerank (Ranker): rerank vector_store_connector (VectorStoreConnector): vector store connector - code example: + + Examples: + .. code-block:: python - >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector - >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig - >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever - >>> from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory - embedding_factory = DefaultEmbeddingFactory() - from dbgpt.rag.retriever.embedding import EmbeddingRetriever - from dbgpt.storage.vector_store.connector import VectorStoreConnector + from dbgpt.storage.vector_store.connector import VectorStoreConnector + from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig + from dbgpt.rag.retriever.embedding import EmbeddingRetriever + from dbgpt.rag.embedding.embedding_factory import ( + DefaultEmbeddingFactory, + ) - embedding_fn = embedding_factory.create( - model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] - ) - vector_name = "test" - config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn) - vector_store_connector = VectorStoreConnector( - vector_store_type=""Chroma"", - vector_store_config=config, - ) - embedding_retriever = EmbeddingRetriever( - top_k=3, vector_store_connector=vector_store_connector - ) - chunks = embedding_retriever.retrieve("your query text") - print(f"embedding retriever results:{[chunk.content for chunk in chunks]}") + embedding_factory = DefaultEmbeddingFactory() + from dbgpt.rag.retriever.embedding import EmbeddingRetriever + from dbgpt.storage.vector_store.connector import VectorStoreConnector + + embedding_fn = embedding_factory.create( + model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] + ) + vector_name = "test" + config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn) + vector_store_connector = VectorStoreConnector( + vector_store_type="Chroma", + vector_store_config=config, + ) + embedding_retriever = EmbeddingRetriever( + top_k=3, vector_store_connector=vector_store_connector + ) + chunks = embedding_retriever.retrieve("your query text") + print( + f"embedding retriever results:{[chunk.content for chunk in chunks]}" + ) """ self._top_k = top_k self._query_rewrite = query_rewrite diff --git a/dbgpt/rag/retriever/rewrite.py b/dbgpt/rag/retriever/rewrite.py index 82b460647..b85d01a87 100644 --- a/dbgpt/rag/retriever/rewrite.py +++ b/dbgpt/rag/retriever/rewrite.py @@ -1,5 +1,6 @@ from typing import List, Optional -from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType + +from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest REWRITE_PROMPT_TEMPLATE_EN = """ Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: '": diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index 842aac037..349111309 100644 --- 
a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -1,6 +1,7 @@ -import pytest -from unittest.mock import MagicMock, patch, AsyncMock from typing import List +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest import dbgpt from dbgpt.rag.chunk import Chunk diff --git a/dbgpt/rag/retriever/tests/test_embedding.py b/dbgpt/rag/retriever/tests/test_embedding.py index 9d8bc7be1..7c9f79dee 100644 --- a/dbgpt/rag/retriever/tests/test_embedding.py +++ b/dbgpt/rag/retriever/tests/test_embedding.py @@ -1,4 +1,5 @@ from unittest.mock import MagicMock + import pytest from dbgpt.rag.chunk import Chunk diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 34947ff19..359d62cff 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -1,12 +1,9 @@ import logging - import traceback -from dbgpt.component import SystemApp -from dbgpt._private.config import Config -from dbgpt.configs.model_config import ( - EMBEDDING_MODEL_CONFIG, -) +from dbgpt._private.config import Config +from dbgpt.component import SystemApp +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary logger = logging.getLogger(__name__) @@ -44,8 +41,8 @@ def db_summary_embedding(self, dbname, db_type): def get_db_summary(self, dbname, query, topk): """get user query related tables info""" - from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig + from dbgpt.storage.vector_store.connector import VectorStoreConnector vector_store_config = VectorStoreConfig(name=dbname + "_profile") vector_connector = VectorStoreConnector.from_default( @@ -82,8 +79,8 @@ def init_db_profile(self, db_summary_client, dbname): dbname(str): dbname """ vector_store_name = dbname + "_profile" - from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig + from dbgpt.storage.vector_store.connector import VectorStoreConnector vector_store_config = VectorStoreConfig(name=vector_store_name) vector_connector = VectorStoreConnector.from_default( diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index 6bec2e7c2..0b7333aa7 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -1,7 +1,8 @@ from typing import List + from dbgpt._private.config import Config -from dbgpt.rag.summary.db_summary import DBSummary from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.rag.summary.db_summary import DBSummary CFG = Config() diff --git a/dbgpt/rag/text_splitter/pre_text_splitter.py b/dbgpt/rag/text_splitter/pre_text_splitter.py index 65aa54d86..32178dafe 100644 --- a/dbgpt/rag/text_splitter/pre_text_splitter.py +++ b/dbgpt/rag/text_splitter/pre_text_splitter.py @@ -1,6 +1,6 @@ from typing import Iterable, List -from dbgpt.rag.chunk import Document, Chunk +from dbgpt.rag.chunk import Chunk, Document from dbgpt.rag.text_splitter.text_splitter import TextSplitter diff --git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py index 697031ef9..21b8cae02 100644 --- a/dbgpt/rag/text_splitter/text_splitter.py +++ b/dbgpt/rag/text_splitter/text_splitter.py @@ -1,7 +1,7 @@ import copy import logging import re -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from typing import ( 
Any, Callable, @@ -14,7 +14,7 @@ Union, ) -from dbgpt.rag.chunk import Document, Chunk +from dbgpt.rag.chunk import Chunk, Document logger = logging.getLogger(__name__) diff --git a/dbgpt/rag/text_splitter/token_splitter.py b/dbgpt/rag/text_splitter/token_splitter.py index 15605ae04..f00be63e4 100644 --- a/dbgpt/rag/text_splitter/token_splitter.py +++ b/dbgpt/rag/text_splitter/token_splitter.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, PrivateAttr from dbgpt.util.global_helper import globals_helper -from dbgpt.util.splitter_utils import split_by_sep, split_by_char +from dbgpt.util.splitter_utils import split_by_char, split_by_sep DEFAULT_METADATA_FORMAT_LEN = 2 DEFAULT_CHUNK_OVERLAP = 20 diff --git a/dbgpt/serve/rag/assembler/base.py b/dbgpt/serve/rag/assembler/base.py index b5944c781..1f9a1de4b 100644 --- a/dbgpt/serve/rag/assembler/base.py +++ b/dbgpt/serve/rag/assembler/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Chunk from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters diff --git a/dbgpt/serve/rag/assembler/db_schema.py b/dbgpt/serve/rag/assembler/db_schema.py index 2d1a98bc2..f7935dcb0 100644 --- a/dbgpt/serve/rag/assembler/db_schema.py +++ b/dbgpt/serve/rag/assembler/db_schema.py @@ -1,11 +1,11 @@ import os -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.chunk import Chunk -from dbgpt.rag.chunk_manager import ChunkParameters, ChunkManager +from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory -from dbgpt.rag.knowledge.base import Knowledge, ChunkStrategy +from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge from dbgpt.rag.knowledge.factory import KnowledgeFactory from dbgpt.rag.retriever.db_schema import DBSchemaRetriever from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary diff --git a/dbgpt/serve/rag/assembler/embedding.py b/dbgpt/serve/rag/assembler/embedding.py index 61f4196a6..bc536147b 100644 --- a/dbgpt/serve/rag/assembler/embedding.py +++ b/dbgpt/serve/rag/assembler/embedding.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.rag.chunk import Chunk from dbgpt.rag.chunk_manager import ChunkParameters diff --git a/dbgpt/serve/rag/assembler/summary.py b/dbgpt/serve/rag/assembler/summary.py index 2294cb990..927fa024e 100644 --- a/dbgpt/serve/rag/assembler/summary.py +++ b/dbgpt/serve/rag/assembler/summary.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Any, List +from typing import Any, List, Optional from dbgpt.core import LLMClient from dbgpt.rag.chunk import Chunk From 82fa561d545b922479768b52eeb7a9d4de52c5a0 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 13 Jan 2024 22:42:16 +0800 Subject: [PATCH 5/5] chore: Format whole dbgpt pacakge --- Makefile | 35 +--------------- dbgpt/_private/config.py | 3 +- dbgpt/_private/llm_metadata.py | 2 +- dbgpt/_private/pydantic.py | 4 +- dbgpt/serve/prompt/serve.py | 4 +- dbgpt/serve/utils/cli.py | 1 + dbgpt/storage/metadata/db_manager.py | 62 +++++++++++++++++++++------- 7 files changed, 57 insertions(+), 54 deletions(-) diff --git a/Makefile b/Makefile index 99edfdd41..57e83f42e 100644 --- a/Makefile +++ b/Makefile @@ -37,45 +37,14 @@ fmt: setup ## Format Python code # TODO: Use isort to sort Python imports. 
# https://github.com/PyCQA/isort # $(VENV_BIN)/isort . - $(VENV_BIN)/isort dbgpt/agent/ - $(VENV_BIN)/isort dbgpt/app/ - $(VENV_BIN)/isort dbgpt/cli/ - $(VENV_BIN)/isort dbgpt/configs/ - $(VENV_BIN)/isort dbgpt/core/ - $(VENV_BIN)/isort dbgpt/datasource/ - $(VENV_BIN)/isort dbgpt/model/ - $(VENV_BIN)/isort dbgpt/rag/ - # TODO: $(VENV_BIN)/isort dbgpt/serve - $(VENV_BIN)/isort dbgpt/serve/core/ - $(VENV_BIN)/isort dbgpt/serve/agent/ - $(VENV_BIN)/isort dbgpt/serve/conversation/ - $(VENV_BIN)/isort dbgpt/serve/rag/ - $(VENV_BIN)/isort dbgpt/serve/utils/_template_files - $(VENV_BIN)/isort dbgpt/storage/ - $(VENV_BIN)/isort dbgpt/train/ - $(VENV_BIN)/isort dbgpt/util/ - $(VENV_BIN)/isort dbgpt/vis/ - $(VENV_BIN)/isort dbgpt/__init__.py - $(VENV_BIN)/isort dbgpt/component.py + $(VENV_BIN)/isort dbgpt/ $(VENV_BIN)/isort --extend-skip="examples/notebook" examples # https://github.com/psf/black $(VENV_BIN)/black --extend-exclude="examples/notebook" . # TODO: Use blackdoc to format Python doctests. # https://blackdoc.readthedocs.io/en/latest/ # $(VENV_BIN)/blackdoc . - $(VENV_BIN)/blackdoc dbgpt/agent/ - $(VENV_BIN)/blackdoc dbgpt/app/ - $(VENV_BIN)/blackdoc dbgpt/cli/ - $(VENV_BIN)/blackdoc dbgpt/configs/ - $(VENV_BIN)/blackdoc dbgpt/core/ - $(VENV_BIN)/blackdoc dbgpt/datasource/ - $(VENV_BIN)/blackdoc dbgpt/model/ - $(VENV_BIN)/blackdoc dbgpt/rag/ - $(VENV_BIN)/blackdoc dbgpt/serve/ - # TODO: $(VENV_BIN)/blackdoc dbgpt/storage/ - $(VENV_BIN)/blackdoc dbgpt/train/ - $(VENV_BIN)/blackdoc dbgpt/util/ - $(VENV_BIN)/blackdoc dbgpt/vis/ + $(VENV_BIN)/blackdoc dbgpt $(VENV_BIN)/blackdoc examples # TODO: Type checking of Python code. # https://github.com/python/mypy diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 63811e17f..74a0e6308 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -3,12 +3,13 @@ from __future__ import annotations import os -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from dbgpt.util.singleton import Singleton if TYPE_CHECKING: from auto_gpt_plugin_template import AutoGPTPluginTemplate + from dbgpt.component import SystemApp diff --git a/dbgpt/_private/llm_metadata.py b/dbgpt/_private/llm_metadata.py index 38f0b4454..a661c04ef 100644 --- a/dbgpt/_private/llm_metadata.py +++ b/dbgpt/_private/llm_metadata.py @@ -1,4 +1,4 @@ -from dbgpt._private.pydantic import Field, BaseModel +from dbgpt._private.pydantic import BaseModel, Field DEFAULT_CONTEXT_WINDOW = 3900 DEFAULT_NUM_OUTPUTS = 256 diff --git a/dbgpt/_private/pydantic.py b/dbgpt/_private/pydantic.py index c2f04928f..70c53b429 100644 --- a/dbgpt/_private/pydantic.py +++ b/dbgpt/_private/pydantic.py @@ -10,10 +10,10 @@ NonNegativeInt, PositiveFloat, PositiveInt, + PrivateAttr, ValidationError, root_validator, validator, - PrivateAttr, ) else: PYDANTIC_VERSION = 2 @@ -26,10 +26,10 @@ NonNegativeInt, PositiveFloat, PositiveInt, + PrivateAttr, ValidationError, root_validator, validator, - PrivateAttr, ) diff --git a/dbgpt/serve/prompt/serve.py b/dbgpt/serve/prompt/serve.py index 0fc6dd3eb..db7cc1c00 100644 --- a/dbgpt/serve/prompt/serve.py +++ b/dbgpt/serve/prompt/serve.py @@ -5,9 +5,9 @@ from dbgpt.component import SystemApp from dbgpt.core import PromptManager - -from dbgpt.storage.metadata import DatabaseManager from dbgpt.serve.core import BaseServe +from dbgpt.storage.metadata import DatabaseManager + from .api.endpoints import init_endpoints, router from .config import ( APP_NAME, diff --git a/dbgpt/serve/utils/cli.py 
index 78cd742b7..8ce4b51a5 100644
--- a/dbgpt/serve/utils/cli.py
+++ b/dbgpt/serve/utils/cli.py
@@ -1,4 +1,5 @@
 import os
+
 import click
 
diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py
index 67ec35954..7f52939c1 100644
--- a/dbgpt/storage/metadata/db_manager.py
+++ b/dbgpt/storage/metadata/db_manager.py
@@ -41,14 +41,19 @@ def paginate_query(
         .. code-block:: python
 
             from dbgpt.storage.metadata import db, Model
+
+
             class User(Model):
-                __tablename__ = "user"
-                id = Column(Integer, primary_key=True)
-                name = Column(String(50))
-                fullname = Column(String(50))
+                __tablename__ = "user"
+                id = Column(Integer, primary_key=True)
+                name = Column(String(50))
+                fullname = Column(String(50))
+
             with db.session() as session:
-                pagination = session.query(User).paginate_query(page=1, page_size=10)
+                pagination = session.query(User).paginate_query(
+                    page=1, page_size=10
+                )
                 print(pagination)
@@ -100,25 +105,37 @@ class DatabaseManager:
             from urllib.parse import quote_plus as urlquote, quote
             from dbgpt.storage.metadata import DatabaseManager, create_model
+
             db = DatabaseManager()
             # Use sqlite with memory storage.
             url = f"sqlite:///:memory:"
-            engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
+            engine_args = {
+                "pool_size": 10,
+                "max_overflow": 20,
+                "pool_timeout": 30,
+                "pool_recycle": 3600,
+                "pool_pre_ping": True,
+            }
             db.init_db(url, engine_args=engine_args)
             Model = create_model(db)
+
             class User(Model):
                 __tablename__ = "user"
                 id = Column(Integer, primary_key=True)
                 name = Column(String(50))
                 fullname = Column(String(50))
+
             with db.session() as session:
                 session.add(User(name="test", fullname="test"))
                 # db will commit the session automatically default.
                 # session.commit()
-                assert session.query(User).filter(User.name == "test").first().name == "test"
+                assert (
+                    session.query(User).filter(User.name == "test").first().name
+                    == "test"
+                )
             # More usage:
@@ -307,6 +324,7 @@ def init_default_db(
             >>> db.init_default_db(sqlite_path)
             >>> with db.session() as session:
             ...     session.query(...)
+            ...
 
         Args:
             sqlite_path (str): The sqlite path.
@@ -353,12 +371,17 @@ def build_from(
             from dbgpt.storage.metadata import DatabaseManager
             from sqlalchemy import Column, Integer, String
+
             db = DatabaseManager.build_from("sqlite:///:memory:")
+
+
             class User(db.Model):
                 __tablename__ = "user"
                 id = Column(Integer, primary_key=True)
                 name = Column(String(50))
                 fullname = Column(String(50))
+
+
             db.create_all()
             with db.session() as session:
                 session.add(User(name="test", fullname="test"))
@@ -397,7 +420,8 @@ class User(db.Model):
            >>> sqlite_path = "/tmp/dbgpt.db"
            >>> db.init_default_db(sqlite_path)
            >>> with db.session() as session:
-           >>>     session.query(...)
+           ...     session.query(...)
+           ...
 
            >>> from dbgpt.storage.metadata import db, Model
            >>> from urllib.parse import quote_plus as urlquote, quote
@@ -407,16 +431,24 @@ class User(db.Model):
            >>> user = "root"
            >>> password = "123456"
            >>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
-           >>> engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
+           >>> engine_args = {
+           ...     "pool_size": 10,
+           ...     "max_overflow": 20,
+           ...     "pool_timeout": 30,
+           ...     "pool_recycle": 3600,
+           ...     "pool_pre_ping": True,
+           ... }
            >>> db.init_db(url, engine_args=engine_args)
            >>> class User(Model):
-           >>>     __tablename__ = "user"
-           >>>     id = Column(Integer, primary_key=True)
-           >>>     name = Column(String(50))
-           >>>     fullname = Column(String(50))
+           ...     __tablename__ = "user"
+           ...     id = Column(Integer, primary_key=True)
+           ...     name = Column(String(50))
+           ...     fullname = Column(String(50))
+           ...
            >>> with db.session() as session:
-           >>>     session.add(User(name="test", fullname="test"))
-           >>>     session.commit()
+           ...     session.add(User(name="test", fullname="test"))
+           ...     session.commit()
+           ...
     """
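
The reformatted db_manager.py docstrings above all exercise the same workflow: build a DatabaseManager, declare a model class, create the tables, and run queries inside the db.session() context manager, which commits automatically when the block exits. Pulling those pieces together, a minimal end-to-end sketch might look like the following; it assumes only the dbgpt.storage.metadata API exactly as those docstrings show it, and the User table with its columns is purely illustrative:

.. code-block:: python

    from sqlalchemy import Column, Integer, String

    from dbgpt.storage.metadata import DatabaseManager

    # An in-memory SQLite database, as in the build_from docstring above.
    db = DatabaseManager.build_from("sqlite:///:memory:")


    class User(db.Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        fullname = Column(String(50))


    # Create tables for all models registered on this manager.
    db.create_all()

    with db.session() as session:
        # The session is committed automatically when the block exits.
        session.add(User(name="test", fullname="test"))
        assert (
            session.query(User).filter(User.name == "test").first().name == "test"
        )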