diff --git a/dbgpt/datasource/conn_tugraph.py b/dbgpt/datasource/conn_tugraph.py index c917365ed..363c7467f 100644 --- a/dbgpt/datasource/conn_tugraph.py +++ b/dbgpt/datasource/conn_tugraph.py @@ -35,6 +35,19 @@ def create_graph(self, graph_name: str) -> bool: return not exists + def is_exist(self, graph_name: str) -> bool: + """Check a new graph in the database if it doesn't already exist.""" + try: + with self._driver.session(database="default") as session: + graph_list = session.run("CALL dbms.graph.listGraphs()").data() + exists = any(item["graph_name"] == graph_name for item in graph_list) + except Exception as e: + raise Exception( + f"Failed to check graph exist'{graph_name}': {str(e)}" + ) from e + + return exists + def delete_graph(self, graph_name: str) -> None: """Delete a graph in the database if it exists.""" with self._driver.session(database="default") as session: diff --git a/dbgpt/storage/graph_store/base.py b/dbgpt/storage/graph_store/base.py index a3344eeea..64c94b916 100644 --- a/dbgpt/storage/graph_store/base.py +++ b/dbgpt/storage/graph_store/base.py @@ -40,3 +40,7 @@ def __init__(self, config: GraphStoreConfig): @abstractmethod def get_config(self) -> GraphStoreConfig: """Get the graph store config.""" + + def is_exist(self, name) -> bool: + """Check Graph Name is Exist.""" + raise NotImplementedError diff --git a/dbgpt/storage/graph_store/tugraph_store.py b/dbgpt/storage/graph_store/tugraph_store.py index 4838906a8..5ee0ef178 100644 --- a/dbgpt/storage/graph_store/tugraph_store.py +++ b/dbgpt/storage/graph_store/tugraph_store.py @@ -102,6 +102,10 @@ def get_config(self) -> TuGraphStoreConfig: """Get the TuGraph store config.""" return self._config + def is_exist(self, name) -> bool: + """Check Graph Name is Exist.""" + return self.conn.is_exist(name) + def _add_vertex_index(self, field_name): """Add an index to the vertex table.""" # TODO: Not used in the current implementation. diff --git a/dbgpt/storage/knowledge_graph/community_summary.py b/dbgpt/storage/knowledge_graph/community_summary.py index 806f9df54..2d756ea65 100644 --- a/dbgpt/storage/knowledge_graph/community_summary.py +++ b/dbgpt/storage/knowledge_graph/community_summary.py @@ -186,6 +186,8 @@ def get_config(self) -> BuiltinKnowledgeGraphConfig: async def aload_document(self, chunks: List[Chunk]) -> List[str]: """Extract and persist graph from the document file.""" + if not self.vector_name_exists(): + self._graph_store_apdater.create_graph(self.get_config().name) await self._aload_document_graph(chunks) await self._aload_triplet_graph(chunks) await self._community_store.build_communities( diff --git a/dbgpt/storage/knowledge_graph/knowledge_graph.py b/dbgpt/storage/knowledge_graph/knowledge_graph.py index ef2d15039..9eb158fd8 100644 --- a/dbgpt/storage/knowledge_graph/knowledge_graph.py +++ b/dbgpt/storage/knowledge_graph/knowledge_graph.py @@ -78,6 +78,8 @@ async def process_chunk(chunk: Chunk): return chunk.chunk_id # wait async tasks completed + if not self.vector_name_exists(): + self._graph_store_apdater.create_graph(self.get_config().name) tasks = [process_chunk(chunk) for chunk in chunks] loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -93,6 +95,8 @@ async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignor Return: List[str]: chunk ids. """ + if not self.vector_name_exists(): + self._graph_store_apdater.create_graph(self.get_config().name) for chunk in chunks: triplets = await self._triplet_extractor.extract(chunk.content) for triplet in triplets: @@ -185,3 +189,7 @@ def delete_by_ids(self, ids: str) -> List[str]: """Delete by ids.""" self._graph_store_apdater.delete_document(chunk_id=ids) return [] + + def vector_name_exists(self) -> bool: + """Whether name exists.""" + return self._graph_store_apdater.graph_store.is_exist(self._config.name)