Skip to content

Commit

Permalink
feat(awel): Modify AWEL http trigger route function
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Nov 22, 2023
1 parent 2ce7751 commit 2274b37
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 41 deletions.
2 changes: 2 additions & 0 deletions examples/awel/simple_chat_dag_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""AWEL: Simple chat dag example
DB-GPT will automatically load and execute the current file after startup.
Example:
.. code-block:: shell
Expand Down
2 changes: 2 additions & 0 deletions examples/awel/simple_dag_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""AWEL: Simple dag example
DB-GPT will automatically load and execute the current file after startup.
Example:
.. code-block:: shell
Expand Down
3 changes: 3 additions & 0 deletions examples/awel/simple_rag_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""AWEL: Simple rag example
DB-GPT will automatically load and execute the current file after startup.
Example:
.. code-block:: shell
Expand Down Expand Up @@ -49,6 +51,7 @@ async def map(self, input_value: ConversationVo) -> ChatContext:
"/examples/simple_rag", methods="POST", request_body=ConversationVo
)
req_parse_task = RequestParseOperator()
# TODO should register prompt template first
prompt_task = PromptManagerOperator()
history_storage_task = ChatHistoryStorageOperator()
history_task = ChatHistoryOperator()
Expand Down
6 changes: 1 addition & 5 deletions pilot/awel/trigger/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from abc import ABC, abstractmethod

from ..operator.base import BaseOperator
from ..operator.common_operator import TriggerOperator
from ..dag.base import DAGContext
from ..task.base import TaskOutput


class Trigger(TriggerOperator, ABC):
@abstractmethod
async def trigger(self, end_operator: "BaseOperator") -> None:
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""
80 changes: 50 additions & 30 deletions pilot/awel/trigger/http_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging

from .base import Trigger
from ..dag.base import DAG
from ..operator.base import BaseOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,46 +51,33 @@ async def trigger(self) -> None:

def mount_to_router(self, router: "APIRouter") -> None:
from fastapi import Depends
from fastapi.responses import StreamingResponse

methods = self._methods if isinstance(self._methods, list) else [self._methods]

def create_route_function(name):
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)

async def route_function(body: Any = Depends(_request_body_dependency)):
end_node = self.dag.leaf_nodes
if len(end_node) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
if not self._streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = self._response_headers
media_type = (
self._response_media_type
if self._response_media_type
else "text/event-stream"
)
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)
async def route_function(body=Depends(_request_body_dependency)):
return await _trigger_dag(
body,
self.dag,
self._streaming_response,
self._response_headers,
self._response_media_type,
)

route_function.__name__ = name
return route_function

function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
dynamic_route_function = create_route_function(function_name)
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
request_model = (
self._req_body
if isinstance(self._req_body, type)
and issubclass(self._req_body, BaseModel)
else None
)
dynamic_route_function = create_route_function(function_name, request_model)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
)
Expand All @@ -115,3 +103,35 @@ async def _parse_request_body(
return request_body_cls(**request.query_params)
else:
return request


async def _trigger_dag(
body: Any,
dag: DAG,
streaming_response: Optional[bool] = False,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
) -> Any:
from fastapi.responses import StreamingResponse

end_node = dag.leaf_nodes
if len(end_node) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
if not streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = response_headers
media_type = response_media_type if response_media_type else "text/event-stream"
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)
7 changes: 1 addition & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,6 @@ def cache_requires():
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]


# def chat_scene():
# setup_spec.extras["chat"] = [
# ""
# ]


def default_requires():
"""
pip install "db-gpt[default]"
Expand All @@ -445,6 +439,7 @@ def default_requires():
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
setup_spec.extras["default"] += setup_spec.extras["torch"]
setup_spec.extras["default"] += setup_spec.extras["quantization"]
setup_spec.extras["default"] += setup_spec.extras["cache"]


def all_requires():
Expand Down

0 comments on commit 2274b37

Please sign in to comment.