Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Refactor storage system #937

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
408 changes: 229 additions & 179 deletions assets/schema/knowledge_management.sql

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(self) -> None:
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20))

self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")

Expand Down
32 changes: 9 additions & 23 deletions dbgpt/agent/db/my_plugin_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model


class MyPluginEntity(Base):
class MyPluginEntity(Model):
__tablename__ = "my_plugin"
__table_args__ = {
"mysql_charset": "utf8mb4",
Expand Down Expand Up @@ -39,16 +33,8 @@ class MyPluginEntity(Base):


class MyPluginDao(BaseDao[MyPluginEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)

def add(self, engity: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
my_plugin = MyPluginEntity(
tenant=engity.tenant,
user_code=engity.user_code,
Expand All @@ -68,13 +54,13 @@ def add(self, engity: MyPluginEntity):
return id

def update(self, entity: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
updated = session.merge(entity)
session.commit()
return updated.id

def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
Expand All @@ -83,7 +69,7 @@ def get_by_user(self, user: str) -> list[MyPluginEntity]:
return result

def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
Expand All @@ -93,7 +79,7 @@ def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
return result

def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
Expand Down Expand Up @@ -122,7 +108,7 @@ def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEnti
return result, total_pages, all_count

def count(self, query: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
Expand All @@ -143,7 +129,7 @@ def count(self, query: MyPluginEntity):
return count

def delete(self, plugin_id: int):
session = self.get_session()
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = MyPluginEntity(id=plugin_id)
Expand Down
32 changes: 9 additions & 23 deletions dbgpt/agent/db/plugin_hub_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model

# TODO We should consider that the production environment does not have permission to execute the DDL
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")


class PluginHubEntity(Base):
class PluginHubEntity(Model):
__tablename__ = "plugin_hub"
__table_args__ = {
"mysql_charset": "utf8mb4",
Expand Down Expand Up @@ -43,16 +37,8 @@ class PluginHubEntity(Base):


class PluginHubDao(BaseDao[PluginHubEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)

def add(self, engity: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=engity.name,
Expand All @@ -71,7 +57,7 @@ def add(self, engity: PluginHubEntity):
return id

def update(self, entity: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
Expand All @@ -82,7 +68,7 @@ def update(self, entity: PluginHubEntity):
def list(
self, query: PluginHubEntity, page=1, page_size=20
) -> list[PluginHubEntity]:
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()

Expand Down Expand Up @@ -111,23 +97,23 @@ def list(
return result, total_pages, all_count

def get_by_storage_url(self, storage_url):
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
session.close()
return result

def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
session.close()
return result

def count(self, query: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(func.count(PluginHubEntity.id))
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
Expand All @@ -146,7 +132,7 @@ def count(self, query: PluginHubEntity):
return count

def delete(self, plugin_id: int):
session = self.get_session()
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)
Expand Down
40 changes: 15 additions & 25 deletions dbgpt/agent/hub/agent_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,12 @@ def install_plugin(self, plugin_name: str, user_name: str = None):
else:
my_plugin_entity.user_code = Default_User

with self.hub_dao.get_session() as session:
try:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
session.commit()
session.close()
except Exception as e:
logger.error("install merge roll back!" + str(e))
session.rollback()
with self.hub_dao.session() as session:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
except Exception as e:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
Expand All @@ -87,19 +81,15 @@ def uninstall_plugin(self, plugin_name, user):
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
session.commit()
except:
session.rollback()
with self.hub_dao.session() as session:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)

if plugin_entity is not None:
# delete package file if not use
Expand Down
Loading