Skip to content

Commit

Permalink
1. add interface: insert_graph and stream_query
Browse files Browse the repository at this point in the history
2. add store_config: summary_enabled
3. update tugraph_connect: function run
  • Loading branch information
KingSkyLi committed Aug 13, 2024
1 parent 65cf7be commit d46292d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 15 deletions.
11 changes: 7 additions & 4 deletions dbgpt/datasource/conn_tugraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""TuGraph Connector."""

import json
from typing import Dict, List, cast
from typing import Dict, List, cast, Union, Generator

from .base import BaseConnector

Expand Down Expand Up @@ -88,15 +88,18 @@ def close(self):
"""Close the Neo4j driver."""
self._driver.close()

def run(self, query: str, fetch: str = "all") -> List:
def run(self, query: str, stream: bool = False) -> Union[List, Generator]:
"""Run GQL."""
with self._driver.session(database=self._graph) as session:
result = session.run(query)
return list(result)
if(stream):
for record in result:
yield record
else:
return list(result)

def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get fields about specified graph.
Args:
table_name (str): table name (graph name)
table_type (str): table type (vertex or edge)
Expand Down
83 changes: 72 additions & 11 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""TuGraph vector store."""
import logging
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Generator, Iterator

from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex, Graph

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,6 +44,10 @@ class TuGraphStoreConfig(GraphStoreConfig):
default="label",
description="The label of edge name, `label` by default.",
)
summary_enabled: str = Field(
default=False,
description=""
)


class TuGraphStore(GraphStoreBase):
Expand All @@ -55,6 +59,7 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
self._port = int(os.getenv("TUGRAPH_PORT", 7687)) or config.port
self._username = os.getenv("TUGRAPH_USERNAME", "admin") or config.username
self._password = os.getenv("TUGRAPH_PASSWORD", "73@TuGraph") or config.password
self._summary_enabled = os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED") or config.summary_enabled
self._node_label = (
os.getenv("TUGRAPH_VERTEX_TYPE", "entity") or config.vertex_type
)
Expand All @@ -73,7 +78,6 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
db_name=config.name,
)
self.conn.create_graph(graph_name=config.name)

self._create_schema()

def _check_label(self, elem_type: str):
Expand All @@ -85,16 +89,24 @@ def _check_label(self, elem_type: str):

def _create_schema(self):
if not self._check_label("vertex"):
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._node_label}', "
f"'id', ['id',string,false])"
)
if(self._summary_enabled):
# call function create label
pass
else:
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._node_label}', "
f"'id', ['id',string,false])"
)
self.conn.run(create_vertex_gql)
if not self._check_label("edge"):
create_edge_gql = f"""CALL db.createLabel(
'edge', '{self._edge_label}', '[["{self._node_label}",
"{self._node_label}"]]', ["id",STRING,false])"""
if(self._summary_enabled):
# call function create label
pass
else:
create_edge_gql = f"""CALL db.createLabel(
'edge', '{self._edge_label}', '[["{self._node_label}",
"{self._node_label}"]]', ["id",STRING,false])"""
self.conn.run(create_edge_gql)

def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
Expand Down Expand Up @@ -128,6 +140,22 @@ def escape_quotes(value: str) -> str:
self.conn.run(query=obj_query)
self.conn.run(query=rel_query)

def insert_graph(self, graph:Graph) -> None:
"""Add triplet."""
nodes:Iterator[Vertex] = MemoryGraph.vertices()
edges:Iterator[Edge] = MemoryGraph.edges()
node_list = []
edge_list = []
for node in nodes:
node_list.append({'id':node.vid(),'description':node.get_prop('description')})
node_query = f"""CALL db.upsertVertex("{self._node_label}", {node_list})"""
for edge in edges:
edge_list.append({'id':edge.sid(),'id':edge.tid(),'description':edge.get_prop('description')})
edge_query = f"""CALL db.upsertEdge('{self._edge_label}', '{{"type":"{self._node_label}","key":"id"}}', '{{"type":"{self._node_label}","key":"id"}}', {edge_list})"""

self.conn.run(query=node_query)
self.conn.run(query=edge_query)

def drop(self):
"""Delete Graph."""
self.conn.delete_graph(self._graph_name)
Expand Down Expand Up @@ -237,3 +265,36 @@ def _format_query_data(data):
for edge in graph["edges"]:
mg.append_edge(edge)
return mg

def stream_query(self, query: str) -> Generator[MemoryGraph, None, None]:
"""Execute a stream query."""
from neo4j import graph
for record in self.conn.run(query, stream=True):
mg = MemoryGraph()
for key in record.keys():
value = record[key]
if isinstance(value, graph.Node):
node_id = value._properties["id"]
node = Vertex(node_id)
MemoryGraph.upsert_vertex(node)
elif isinstance(value, graph.Relationship):
rel_nodes = value.nodes
prop_id = value._properties["id"]
src_id = rel_nodes[0]._properties["id"]
dst_id = rel_nodes[1]._properties["id"]
edge = Edge(src_id, dst_id, label=prop_id)
mg.append_edge(edge)
elif isinstance(value, graph.Path):
nodes = list(record["p"].nodes)
rels = list(record["p"].relationships)
formatted_path = []
for i in range(len(nodes)):
formatted_path.append(nodes[i]._properties["id"])
if i < len(rels):
formatted_path.append(rels[i]._properties["id"])
for path in formatted_path:
for i in range(0, len(path), 2):
mg.upsert_vertex(Vertex(path[i]))
if i + 2 < len(path):
mg.append_edge(Edge(path[i], path[i + 2], path[i + 1]))
yield mg

0 comments on commit d46292d

Please sign in to comment.