diff --git a/examples/awel/simple_chat_dag_example.py b/examples/awel/simple_chat_dag_example.py index 4da382d58..b53c2415e 100644 --- a/examples/awel/simple_chat_dag_example.py +++ b/examples/awel/simple_chat_dag_example.py @@ -1,5 +1,7 @@ """AWEL: Simple chat dag example + DB-GPT will automatically load and execute the current file after startup. + Example: .. code-block:: shell diff --git a/examples/awel/simple_dag_example.py b/examples/awel/simple_dag_example.py index 0bdf0dff7..bfc9b45b5 100644 --- a/examples/awel/simple_dag_example.py +++ b/examples/awel/simple_dag_example.py @@ -1,5 +1,7 @@ """AWEL: Simple dag example + DB-GPT will automatically load and execute the current file after startup. + Example: .. code-block:: shell diff --git a/examples/awel/simple_rag_example.py b/examples/awel/simple_rag_example.py index 78c08ac2f..c7cd934dc 100644 --- a/examples/awel/simple_rag_example.py +++ b/examples/awel/simple_rag_example.py @@ -1,5 +1,7 @@ """AWEL: Simple rag example + DB-GPT will automatically load and execute the current file after startup. + Example: .. code-block:: shell @@ -49,6 +51,7 @@ async def map(self, input_value: ConversationVo) -> ChatContext: "/examples/simple_rag", methods="POST", request_body=ConversationVo ) req_parse_task = RequestParseOperator() + # TODO should register prompt template first prompt_task = PromptManagerOperator() history_storage_task = ChatHistoryStorageOperator() history_task = ChatHistoryOperator() diff --git a/pilot/awel/trigger/base.py b/pilot/awel/trigger/base.py index 9cb5d1895..28662498f 100644 --- a/pilot/awel/trigger/base.py +++ b/pilot/awel/trigger/base.py @@ -1,15 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING from abc import ABC, abstractmethod -from ..operator.base import BaseOperator from ..operator.common_operator import TriggerOperator -from ..dag.base import DAGContext -from ..task.base import TaskOutput class Trigger(TriggerOperator, ABC): @abstractmethod - async def trigger(self, end_operator: "BaseOperator") -> None: + async def trigger(self) -> None: """Trigger the workflow or a specific operation in the workflow.""" diff --git a/pilot/awel/trigger/http_trigger.py b/pilot/awel/trigger/http_trigger.py index de459c066..175d4b63f 100644 --- a/pilot/awel/trigger/http_trigger.py +++ b/pilot/awel/trigger/http_trigger.py @@ -7,6 +7,7 @@ import logging from .base import Trigger +from ..dag.base import DAG from ..operator.base import BaseOperator if TYPE_CHECKING: @@ -50,46 +51,33 @@ async def trigger(self) -> None: def mount_to_router(self, router: "APIRouter") -> None: from fastapi import Depends - from fastapi.responses import StreamingResponse methods = self._methods if isinstance(self._methods, list) else [self._methods] - def create_route_function(name): + def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]): async def _request_body_dependency(request: Request): return await _parse_request_body(request, self._req_body) - async def route_function(body: Any = Depends(_request_body_dependency)): - end_node = self.dag.leaf_nodes - if len(end_node) != 1: - raise ValueError("HttpTrigger just support one leaf node in dag") - end_node = end_node[0] - if not self._streaming_response: - return await end_node.call(call_data={"data": body}) - else: - headers = self._response_headers - media_type = ( - self._response_media_type - if self._response_media_type - else "text/event-stream" - ) - if not headers: - headers = { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Transfer-Encoding": "chunked", - } - return StreamingResponse( - end_node.call_stream(call_data={"data": body}), - headers=headers, - media_type=media_type, - ) + async def route_function(body=Depends(_request_body_dependency)): + return await _trigger_dag( + body, + self.dag, + self._streaming_response, + self._response_headers, + self._response_media_type, + ) route_function.__name__ = name return route_function - function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}" - dynamic_route_function = create_route_function(function_name) + function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}" + request_model = ( + self._req_body + if isinstance(self._req_body, type) + and issubclass(self._req_body, BaseModel) + else None + ) + dynamic_route_function = create_route_function(function_name, request_model) logger.info( f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}" ) @@ -115,3 +103,35 @@ async def _parse_request_body( return request_body_cls(**request.query_params) else: return request + + +async def _trigger_dag( + body: Any, + dag: DAG, + streaming_response: Optional[bool] = False, + response_headers: Optional[Dict[str, str]] = None, + response_media_type: Optional[str] = None, +) -> Any: + from fastapi.responses import StreamingResponse + + end_node = dag.leaf_nodes + if len(end_node) != 1: + raise ValueError("HttpTrigger just support one leaf node in dag") + end_node = end_node[0] + if not streaming_response: + return await end_node.call(call_data={"data": body}) + else: + headers = response_headers + media_type = response_media_type if response_media_type else "text/event-stream" + if not headers: + headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + } + return StreamingResponse( + end_node.call_stream(call_data={"data": body}), + headers=headers, + media_type=media_type, + ) diff --git a/setup.py b/setup.py index b74a2f116..6de96aa27 100644 --- a/setup.py +++ b/setup.py @@ -421,12 +421,6 @@ def cache_requires(): setup_spec.extras["cache"] = ["rocksdict", "msgpack"] -# def chat_scene(): -# setup_spec.extras["chat"] = [ -# "" -# ] - - def default_requires(): """ pip install "db-gpt[default]" @@ -445,6 +439,7 @@ def default_requires(): setup_spec.extras["default"] += setup_spec.extras["knowledge"] setup_spec.extras["default"] += setup_spec.extras["torch"] setup_spec.extras["default"] += setup_spec.extras["quantization"] + setup_spec.extras["default"] += setup_spec.extras["cache"] def all_requires():