Skip to content

Commit

Permalink
feat(core): Support simple DB query for sdk (#917)
Browse files Browse the repository at this point in the history
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
  • Loading branch information
fangyinc and chengfangyin2 authored Dec 11, 2023
1 parent 43190ca commit cbba50a
Show file tree
Hide file tree
Showing 18 changed files with 467 additions and 74 deletions.
16 changes: 0 additions & 16 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,19 +266,3 @@ def __init__(self) -> None:
self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
"MODEL_CACHE_STORAGE_DISK_DIR"
)

def set_debug_mode(self, value: bool) -> None:
    """Set the debug mode value.

    Args:
        value (bool): New value stored on ``self.debug_mode``.
    """
    self.debug_mode = value

def set_templature(self, value: float) -> None:
    """Set the temperature value.

    The method name keeps its original (misspelled) form so existing
    callers continue to work.

    Args:
        value (float): New value stored on ``self.temperature``. Annotated
            as ``float`` so fractional temperatures are accepted by type
            checkers; integers remain valid.
    """
    self.temperature = value

def set_speak_mode(self, value: bool) -> None:
    """Set the speak mode value.

    Args:
        value (bool): New value stored on ``self.speak_mode``.
    """
    self.speak_mode = value

def set_last_plugin_return(self, value: bool) -> None:
    """Set the last plugin return value.

    Args:
        value (bool): New value stored on ``self.last_plugin_return``.
    """
    # NOTE(review): the parameter is annotated as bool; confirm that plugin
    # return values really are booleans.
    self.last_plugin_return = value
12 changes: 4 additions & 8 deletions dbgpt/app/dbgpt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from dbgpt._private.config import Config
from dbgpt.configs.model_config import (
LLM_MODEL_CONFIG,
EMBEDDING_MODEL_CONFIG,
LOGDIR,
ROOT_PATH,
)
from dbgpt._private.config import Config
from dbgpt.component import SystemApp

from dbgpt.app.base import (
Expand All @@ -30,7 +30,6 @@
from dbgpt.app.prompt.api import router as prompt_router
from dbgpt.app.llm_manage.api import router as llm_manage_api


from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
from dbgpt.app.openapi.base import validation_exception_handler
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
Expand Down Expand Up @@ -59,7 +58,7 @@ def swagger_monkey_patch(*args, **kwargs):
*args,
**kwargs,
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
)


Expand All @@ -79,13 +78,11 @@ def swagger_monkey_patch(*args, **kwargs):
allow_headers=["*"],
)


app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])


app.include_router(knowledge_router, tags=["Knowledge"])
app.include_router(prompt_router, tags=["Prompt"])

Expand Down Expand Up @@ -133,7 +130,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):

# Before start
system_app.before_start()

model_name = param.model_name or CFG.LLM_MODEL
param.model_name = model_name
print(param)

embedding_model_name = CFG.EMBEDDING_MODEL
Expand All @@ -143,8 +141,6 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)

model_name = param.model_name or CFG.LLM_MODEL

model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
if not param.light:
print("Model Unified Deployment Mode!")
Expand Down
1 change: 1 addition & 0 deletions dbgpt/app/scene/chat_db/auto_execute/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ async def generate_input_values(self) -> Dict:
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try:
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
table_infos = await blocking_func_to_async(
Expand Down
3 changes: 2 additions & 1 deletion dbgpt/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
OnceConversation,
)
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.cache import (
CacheKey,
Expand All @@ -33,6 +33,7 @@
"PromptTemplate",
"PromptTemplateOperator",
"BaseOutputParser",
"SQLOutputParser",
"Serializable",
"Serializer",
"CacheKey",
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/core/awel/task/task_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def new_output(self) -> TaskOutput[T]:

@property
def is_empty(self) -> bool:
    """Return whether this output holds no data.

    Only ``None`` counts as empty. Falsy payloads such as ``0``, ``""``,
    ``False`` or an empty container are real data and are not reported
    as empty.

    Returns:
        bool: True when ``self._data`` is ``None``.
    """
    return self._data is None

async def _apply_func(self, func) -> Any:
if asyncio.iscoroutinefunction(func):
Expand Down
10 changes: 10 additions & 0 deletions dbgpt/core/interface/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,13 @@ def _parse_model_response(response: ResponseTye):
else:
raise ValueError(f"Unsupported response type {type(response)}")
return resp_obj_ex


class SQLOutputParser(BaseOutputParser):
    """Output parser that decodes the cleaned model response as JSON."""

    def __init__(self, is_stream_out: bool = False, **kwargs):
        """Create the parser.

        Args:
            is_stream_out (bool): Forwarded to ``BaseOutputParser``.
            **kwargs: Extra keyword arguments forwarded to ``BaseOutputParser``.
        """
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Parse a non-streaming model response into a Python object.

        Args:
            response (ResponseTye): The raw model response.
            sep (str): Separator passed through to the base implementation.

        Returns:
            The object produced by decoding the cleaned text as JSON.

        Raises:
            json.JSONDecodeError: If the cleaned text is not valid JSON.
        """
        # Reuse the base class to extract the text, then to clean it up.
        raw_text = super().parse_model_nostream_resp(response, sep)
        cleaned_text = super().parse_prompt_response(raw_text)
        return json.loads(cleaned_text, strict=True)
23 changes: 23 additions & 0 deletions dbgpt/core/interface/retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT


class RetrieverOperator(MapOperator[IN, OUT]):
    """The abstract retriever operator.

    Subclasses implement the synchronous :meth:`retrieve`; :meth:`map`
    exposes it as a coroutine.
    """

    async def map(self, input_value: IN) -> OUT:
        """Map input value to output value.

        Args:
            input_value (IN): The input value.

        Returns:
            OUT: The output value returned by :meth:`retrieve`.
        """
        # ``retrieve`` is a blocking call, so it is dispatched through
        # ``blocking_func_to_async`` instead of being invoked directly here.
        retrieved = await self.blocking_func_to_async(self.retrieve, input_value)
        return retrieved

    @abstractmethod
    def retrieve(self, input_value: IN) -> OUT:
        """Retrieve data for input value.

        Args:
            input_value (IN): The input value.

        Returns:
            OUT: The retrieved data.
        """
70 changes: 58 additions & 12 deletions dbgpt/datasource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,104 @@
# -*- coding:utf-8 -*-

"""We need to design a base class. That other connector can Write with this"""
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional
from abc import ABC
from typing import Iterable, List, Optional


class BaseConnect(ABC):
    """Base class that datasource connectors build on.

    Most methods here are placeholders that do nothing and return ``None``;
    concrete connectors are expected to override them. A few methods return
    fixed defaults (see ``get_users``, ``get_grants``, ``get_collation``,
    ``get_charset`` and ``is_normal_type``).
    """

    def get_connect(self, db_name: str):
        """Placeholder for obtaining a connection.

        Args:
            db_name (str): database name
        """
        pass

    def get_table_names(self) -> Iterable[str]:
        """Get all table names"""
        pass

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get table info about specified table.

        Args:
            table_names (Optional[List[str]]): table names

        Returns:
            str: Table information joined by '\\n\\n'
        """
        pass

    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get index info about specified table.

        Args:
            table_names (Optional[List[str]]): table names
        """
        pass

    def get_example_data(self, table: str, count: int = 3):
        """Get example data about specified table.

        Not used now.

        Args:
            table (str): table name
            count (int): example data count
        """
        pass

    def get_database_list(self) -> List[str]:
        """Get database list.

        Returns:
            List[str]: database list
        """
        pass

    def get_database_names(self):
        """Get database names."""
        pass

    def get_table_comments(self, db_name):
        """Get table comments.

        Args:
            db_name (str): database name
        """
        pass

    def run(self, command: str, fetch: str = "all") -> List:
        """Execute sql command.

        Args:
            command (str): sql command
            fetch (str): fetch type
        """
        pass

    def run_to_df(self, command: str, fetch: str = "all"):
        """Execute sql command and return dataframe.

        Args:
            command (str): sql command
            fetch (str): fetch type
        """
        pass

    def get_users(self):
        """Get user info.

        Returns:
            An empty list in this base implementation.
        """
        return []

    def get_grants(self):
        """Get grant info.

        Returns:
            An empty list in this base implementation.
        """
        return []

    def get_collation(self):
        """Get collation.

        Returns:
            ``None`` in this base implementation.
        """
        return None

    def get_charset(self) -> str:
        """Get character_set of current database.

        Returns:
            str: ``"utf-8"`` in this base implementation.
        """
        return "utf-8"

    def get_fields(self, table_name):
        """Get column fields about specified table."""
        pass

    def get_show_create_table(self, table_name):
        """Get the creation table sql about specified table."""
        pass

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        pass

    @classmethod
    def is_normal_type(cls) -> bool:
        """Return whether the connector is a normal type.

        Returns:
            bool: ``True`` in this base implementation.
        """
        return True
2 changes: 1 addition & 1 deletion dbgpt/datasource/db_conn_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel


class DBConfig(BaseModel):
Expand Down
7 changes: 4 additions & 3 deletions dbgpt/datasource/manages/connection_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List, Type
from dbgpt.datasource import ConnectConfigDao
from dbgpt.storage.schema import DBType
from dbgpt.component import SystemApp, ComponentType
Expand All @@ -21,7 +22,7 @@
class ConnectManager:
"""db connect manager"""

def get_all_subclasses(self, cls):
def get_all_subclasses(self, cls: Type[BaseConnect]) -> List[Type[BaseConnect]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += self.get_all_subclasses(subclass)
Expand All @@ -31,15 +32,15 @@ def get_all_completed_types(self):
chat_classes = self.get_all_subclasses(BaseConnect)
support_types = []
for cls in chat_classes:
if cls.db_type:
if cls.db_type and cls.is_normal_type():
support_types.append(DBType.of_db_type(cls.db_type))
return support_types

def get_cls_by_dbtype(self, db_type):
chat_classes = self.get_all_subclasses(BaseConnect)
result = None
for cls in chat_classes:
if cls.db_type == db_type:
if cls.db_type == db_type and cls.is_normal_type():
result = cls
if not result:
raise ValueError("Unsupported Db Type!" + db_type)
Expand Down
Empty file.
16 changes: 16 additions & 0 deletions dbgpt/datasource/operator/datasource_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Any
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
from dbgpt.datasource.base import BaseConnect


class DatasourceOperator(MapOperator[str, Any]):
    """Operator that runs its string input as a command on a connection.

    The class is parameterised as ``MapOperator[str, Any]``, so the method
    signatures below use ``str`` and ``Any`` to match.
    """

    def __init__(self, connection: BaseConnect, **kwargs):
        """Create the operator.

        Args:
            connection (BaseConnect): Connection used to execute commands.
            **kwargs: Extra keyword arguments forwarded to ``MapOperator``.
        """
        super().__init__(**kwargs)
        self._connection = connection

    async def map(self, input_value: str) -> Any:
        """Execute the command without calling the blocking query directly.

        Args:
            input_value (str): The command to execute.

        Returns:
            Any: Whatever :meth:`query` returns.
        """
        # ``query`` is synchronous, so it is dispatched through
        # ``blocking_func_to_async`` rather than invoked inline.
        return await self.blocking_func_to_async(self.query, input_value)

    def query(self, input_value: str) -> Any:
        """Run the command through the connection's ``run_to_df``.

        Args:
            input_value (str): The command to execute.

        Returns:
            Any: The result of ``run_to_df`` on the wrapped connection.
        """
        return self._connection.run_to_df(input_value)
Loading

0 comments on commit cbba50a

Please sign in to comment.