Skip to content

Commit

Permalink
refactor: update datasource connection interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt committed Dec 21, 2023
1 parent 6440b33 commit 1a707a6
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 18 deletions.
30 changes: 27 additions & 3 deletions dbgpt/datasource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""We need to design a base class. That other connector can Write with this"""
from abc import ABC
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, Any, Dict


class BaseConnect(ABC):
Expand Down Expand Up @@ -58,6 +58,25 @@ def get_table_comments(self, db_name):
"""
pass

def get_table_comment(self, table_name: str) -> Dict:
"""Get table comment.
Args:
table_name (str): table name
Returns:
comment: Dict, which contains text: Optional[str], eg:["text": "comment"]
"""
pass

def get_columns(self, table_name: str) -> Any:
"""Get columns.
Args:
table_name (_type_): _description_
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'int', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
pass

def get_column_comments(self, db_name, table_name):
"""Get column comments.
Expand Down Expand Up @@ -107,8 +126,13 @@ def get_show_create_table(self, table_name):
"""Get the creation table sql about specified table."""
pass

def get_indexes(self, table_name):
"""Get table indexes about specified table."""
def get_indexes(self, table_name) -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name (str): table name
Returns:
indexes: List[Dict], eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
pass

@classmethod
Expand Down
36 changes: 29 additions & 7 deletions dbgpt/datasource/rdbms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from typing import Any, Iterable, List, Optional
from typing import Any, Iterable, List, Optional, Dict
import sqlalchemy
from sqlalchemy import (
MetaData,
Expand Down Expand Up @@ -227,6 +227,16 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
final_str = "\n\n".join(tables)
return final_str

def get_columns(self, table_name: str) -> Any:
"""Get columns.
Args:
table_name (_type_): _description_
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'int', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
return self._inspector.get_columns(table_name)

def _get_sample_rows(self, table: Table) -> str:
# build the select command
command = select(table).limit(self._sample_rows_in_table_info)
Expand Down Expand Up @@ -475,12 +485,14 @@ def _extract_table_name_from_ddl(self, parsed):
return token.get_real_name()
return None

def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW INDEXES FROM {table_name}"))
indexes = cursor.fetchall()
return [(index[2], index[4]) for index in indexes]
def get_indexes(self, table_name) -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name:table name
Returns:
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
return self._inspector.get_indexes(table_name)

def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
Expand Down Expand Up @@ -550,6 +562,16 @@ def get_table_comments(self, db_name):
(table_comment[0], table_comment[1]) for table_comment in table_comments
]

def get_table_comment(self, table_name: str) -> Dict:
"""Get table comments.
Args:
table_name (str): table name
Returns:
comment: Dict, which contains text: Optional[str], eg:["text": "comment"]
"""
return self._inspector.get_table_comment(table_name)

def get_column_comments(self, db_name, table_name):
cursor = self.session.execute(
text(
Expand Down
56 changes: 51 additions & 5 deletions dbgpt/datasource/rdbms/conn_clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import sqlparse
import clickhouse_connect
from typing import List, Optional, Any, Iterable
from typing import List, Optional, Any, Iterable, Dict
from sqlalchemy import text
from urllib.parse import quote
from sqlalchemy.schema import CreateTable
Expand Down Expand Up @@ -60,7 +60,6 @@ def from_uri_db(
client = clickhouse_connect.get_client(
host=host,
user=user,
password=pwd,
port=port,
connect_timeout=15,
database=db_name,
Expand All @@ -84,9 +83,24 @@ def get_table_names(self):
tables = [row[0] for block in stream for row in block]
return tables

def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return ""
def get_indexes(self, table_name) -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name (str): table name
Returns:
indexes: List[Dict], eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
session = self.client

_query_sql = f"""
SELECT name AS table, primary_key, from system.tables where database ='{self.client.database}' and table = '{table_name}'
"""
with session.query_row_block_stream(_query_sql) as stream:
indexes = [block for block in stream]
return [
{"name": "primary_key", "column_names": column_names.split(",")}
for table, column_names in indexes[0]
]

@property
def table_info(self) -> str:
Expand Down Expand Up @@ -117,6 +131,20 @@ def get_show_create_table(self, table_name):
ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
return ans

def get_columns(self, table_name) -> List[Dict]:
"""Get columns.
Args:
table_name (_type_): _description_
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
fields = self.get_fields(table_name)
return [
{"name": name, "comment": comment, "type": column_type}
for name, column_type, _, _, comment in fields[0]
]

def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self.client
Expand Down Expand Up @@ -211,6 +239,24 @@ def get_table_comments(self, db_name):
table_comments = [row for block in stream for row in block]
return table_comments

def get_table_comment(self, table_name: str) -> Dict:
"""Get table comment.
Args:
table_name (str): table name
Returns:
comment: Dict, which contains text: Optional[str], eg:["text": "comment"]
"""
session = self.client

_query_sql = f"""
SELECT table, comment FROM system.tables WHERE database = '{self.client.database}'and table = '{table_name}'""".format(
self.client.database
)

with session.query_row_block_stream(_query_sql) as stream:
table_comments = [row for block in stream for row in block]
return [{"text": comment} for table_name, comment in table_comments][0]

def get_column_comments(self, db_name, table_name):
session = self.client
_query_sql = f"""
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/rag/summary/rdbms_db_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,23 @@ def _parse_table_summary(
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
"""
columns = []
for column in conn._inspector.get_columns(table_name):
for column in conn.get_columns(table_name):
if column.get("comment"):
columns.append(f"{column['name']} ({column.get('comment')})")
else:
columns.append(f"{column['name']}")

column_str = ", ".join(columns)
index_keys = []
for index_key in conn._inspector.get_indexes(table_name):
for index_key in conn.get_indexes(table_name):
key_str = ", ".join(index_key["column_names"])
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
table_str = summary_template.format(table_name=table_name, columns=column_str)
if len(index_keys) > 0:
index_key_str = ", ".join(index_keys)
table_str += f", and index keys: {index_key_str}"
try:
comment = conn._inspector.get_table_comment(table_name)
comment = conn.get_table_comment(table_name)
except Exception:
comment = dict(text=None)
if comment.get("text"):
Expand Down

0 comments on commit 1a707a6

Please sign in to comment.