Skip to content

Commit

Permalink
feat: Enhance the triplets extraction in the knowledge graph by the b…
Browse files Browse the repository at this point in the history
…atch size (#2091)
  • Loading branch information
Appointat authored Nov 5, 2024
1 parent b4ce217 commit 25d47ce
Show file tree
Hide file tree
Showing 10 changed files with 362 additions and 244 deletions.
1 change: 1 addition & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks

KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
Expand Down
10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Transformer base class."""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional
Expand Down Expand Up @@ -37,6 +38,15 @@ class ExtractorBase(TransformerBase, ABC):
async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract results from text."""

@abstractmethod
async def batch_extract(
self,
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List:
"""Batch extract results from texts.

Abstract hook: concrete extractors must provide the implementation.

Args:
texts: The input texts to extract from.
batch_size: Size of each batch; defaults to 1. How batches are
scheduled is up to the implementing subclass.
limit: Optional limit forwarded to the extraction; its exact
meaning is defined by the implementing subclass.

Returns:
A list of extraction results; the element type is left to the
implementing subclass.
"""


class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""
98 changes: 80 additions & 18 deletions dbgpt/rag/transformer/graph_extractor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""GraphExtractor class."""

import asyncio
import logging
import re
from typing import List, Optional
from typing import Dict, List, Optional

from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
Expand All @@ -23,35 +24,96 @@ def __init__(
self._chunk_history = chunk_history

config = self._chunk_history.get_config()

self._vector_space = config.name
self._max_chunks_once_load = config.max_chunks_once_load
self._max_threads = config.max_threads
self._topk = config.topk
self._score_threshold = config.score_threshold

async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Load similar chunks."""
# load similar chunks
chunks = await self._chunk_history.asimilar_search_with_scores(
text, self._topk, self._score_threshold
)
history = [
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
]
context = "\n".join(history) if history else ""

try:
# extract with chunk history
return await super()._extract(text, context, limit)

finally:
# save chunk to history
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
    """Build the chunk-history context for every text.

    For each text, in input order:
      1. search the chunk history for similar chunks,
      2. store the text itself in the chunk history,
      3. record the similar chunks, rendered as numbered sections, as the
         text's context.

    Because every text is written to the history before the next one is
    searched, the order of `texts` matters.

    Args:
        texts: The texts to load context for.

    Returns:
        A mapping from text to its context string (empty when no similar
        chunk was found). Equal texts share one entry; the context computed
        last wins.
    """
    contexts: Dict[str, str] = {}
    store = self._chunk_history

    for text in texts:
        # Look up similar chunks first, so the text is not yet in the history.
        similar = await store.asimilar_search_with_scores(
            text, self._topk, self._score_threshold
        )
        sections = [
            f"Section {pos}:\n{hit.content}"
            for pos, hit in enumerate(similar, start=1)
        ]

        # Persist the text so later texts can find it as history.
        new_chunk = Chunk(content=text, metadata={"relevant_cnt": len(sections)})
        await store.aload_document_with_limit(
            [new_chunk],
            self._max_chunks_once_load,
            self._max_threads,
        )

        # Joining an empty list yields "", i.e. "no context".
        contexts[text] = "\n".join(sections)

    return contexts

async def extract(self, text: str, limit: Optional[int] = None) -> List:
    """Extract graphs from a single text.

    The context of the text is loaded through `aload_chunk_context` and
    handed, together with the text and `limit`, to the parent class's
    `_extract`.

    Suggestion: to extract triplets in batches, call `batch_extract`.
    """
    # The context map is keyed by the text itself.
    contexts = await self.aload_chunk_context([text])

    # Extract with the chunk history as context.
    return await super()._extract(text, contexts[text], limit)

async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> "List[List[Graph]]":
    """Extract graphs from texts in batches.

    The chunk context of all texts is loaded first. The texts are then
    processed in consecutive slices of `batch_size`: extractions inside one
    slice run concurrently, while slices run one after another.

    Args:
        texts: The texts to extract graphs from.
        batch_size: Number of texts extracted concurrently; must be >= 1.
        limit: Optional limit forwarded unchanged to `_extract`.

    Returns:
        One extraction result per input text, in the same order as `texts`
        (text <-> graphs). An empty `texts` yields an empty list.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # 1. Load the chunk context for every text up front.
    text_context_map = await self.aload_chunk_context(texts)

    graphs_list: List[List[Graph]] = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start : start + batch_size]

        # 2. Run the extractions of this batch concurrently.
        # asyncio.gather returns results in the order of the awaitables it
        # was given, so appending batch after batch preserves input order
        # without any index bookkeeping.
        batch_results = await asyncio.gather(
            *(
                self._extract(text, text_context_map[text], limit)
                for text in batch_texts
            )
        )
        graphs_list.extend(batch_results)

    return graphs_list

def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
graph = MemoryGraph()
edge_count = 0
Expand Down
28 changes: 28 additions & 0 deletions dbgpt/rag/transformer/llm_extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""TripletExtractor class."""

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
Expand All @@ -22,6 +24,32 @@ async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract by LLM."""
return await self._extract(text, None, limit)

async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List:
    """Batch extract by LLM.

    The texts are processed in consecutive slices of `batch_size`:
    extractions inside one slice run concurrently, while slices run one
    after another. No history is passed to `_extract`.

    Args:
        texts: The texts to extract from.
        batch_size: Number of texts extracted concurrently; must be >= 1.
        limit: Optional limit forwarded unchanged to `_extract`.

    Returns:
        One extraction result per input text, in the same order as `texts`.
        An empty `texts` yields an empty list.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results: List = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start : start + batch_size]

        # Execute the batch concurrently and wait for all of it to finish;
        # asyncio.gather keeps the results in the order of its awaitables.
        batch_results = await asyncio.gather(
            *(self._extract(text, None, limit) for text in batch_texts)
        )
        results.extend(batch_results)

    return results

async def _extract(
self, text: str, history: str = None, limit: Optional[int] = None
) -> List:
Expand Down
3 changes: 2 additions & 1 deletion dbgpt/rag/transformer/triplet_extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TripletExtractor class."""

import logging
import re
from typing import Any, List, Optional, Tuple
Expand All @@ -12,7 +13,7 @@
"Some text is provided below. Given the text, "
"extract up to knowledge triplets as more as possible "
"in the form of (subject, predicate, object).\n"
"Avoid stopwords.\n"
"Avoid stopwords. The subject, predicate, object can not be none.\n"
"---------------------\n"
"Example:\n"
"Text: Alice is Bob's mother.\n"
Expand Down
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ class GraphStoreConfig(BaseModel):
default=False,
description="Enable graph community summary or not.",
)
document_graph_enabled: bool = Field(
default=True,
description="Enable document graph search or not.",
)
triplet_graph_enabled: bool = Field(
default=True,
description="Enable knowledge graph search or not.",
)


class GraphStoreBase(ABC):
Expand Down
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,6 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary
)
self._enable_document_graph = (
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
or config.document_graph_enabled
)
self._enable_triplet_graph = (
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
or config.triplet_graph_enabled
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
Expand Down
Loading

0 comments on commit 25d47ce

Please sign in to comment.