Skip to content

Commit

Permalink
fix redundant embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
seehi committed Apr 23, 2024
1 parent e212e7b commit 263d799
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 73 deletions.
63 changes: 49 additions & 14 deletions metagpt/rag/engines/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from typing import Any, Optional, Union

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core import SimpleDirectoryReader
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
Expand Down Expand Up @@ -63,15 +63,15 @@ def __init__(
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
transformations: Optional[list[TransformComponent]] = None,
) -> None:
super().__init__(
retriever=retriever,
response_synthesizer=response_synthesizer,
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
self._transformations = transformations or self._default_transformations()

@classmethod
def from_docs(
Expand Down Expand Up @@ -103,12 +103,17 @@ def from_docs(
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)

index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations = transformations or cls._default_transformations()
nodes = run_transformations(documents, transformations=transformations)

return cls._from_nodes(
nodes=nodes,
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)

@classmethod
def from_objs(
Expand Down Expand Up @@ -137,12 +142,15 @@ def from_objs(
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")

nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(

return cls._from_nodes(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)

@classmethod
def from_index(
Expand Down Expand Up @@ -183,7 +191,7 @@ def add_docs(self, input_files: list[str]):
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)

nodes = run_transformations(documents, transformations=self.index._transformations)
nodes = run_transformations(documents, transformations=self._transformations)
self._save_nodes(nodes)

def add_objs(self, objs: list[RAGObject]):
Expand All @@ -199,6 +207,29 @@ def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):

self._persist(str(persist_dir), **kwargs)

@classmethod
def _from_nodes(
    cls,
    nodes: list[BaseNode],
    transformations: Optional[list[TransformComponent]] = None,
    embed_model: BaseEmbedding = None,
    llm: LLM = None,
    retriever_configs: list[BaseRetrieverConfig] = None,
    ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
    """Assemble an engine from nodes that have already been prepared.

    The nodes and the resolved embedding model are handed to the retriever
    factory, and the resulting retriever, the rankers and a response
    synthesizer are passed to the class constructor.

    Args:
        nodes: Nodes forwarded to ``get_retriever``.
        transformations: Forwarded unchanged to the constructor.
        embed_model: Embedding model; resolved through ``_resolve_embed_model``.
        llm: LLM for rankers and synthesis; ``get_rag_llm()`` is used when falsy.
        retriever_configs: Configs forwarded to ``get_retriever``.
        ranker_configs: Configs forwarded to ``get_rankers``.

    Returns:
        A new engine instance.
    """
    resolved_embedding = cls._resolve_embed_model(embed_model, retriever_configs)
    active_llm = llm if llm else get_rag_llm()

    engine_retriever = get_retriever(
        configs=retriever_configs,
        nodes=nodes,
        embed_model=resolved_embedding,
    )
    postprocessors = get_rankers(configs=ranker_configs, llm=active_llm)  # Default []
    synthesizer = get_response_synthesizer(llm=active_llm)

    return cls(
        retriever=engine_retriever,
        node_postprocessors=postprocessors,
        response_synthesizer=synthesizer,
        transformations=transformations,
    )

@classmethod
def _from_index(
cls,
Expand All @@ -208,14 +239,14 @@ def _from_index(
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()

retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []

return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)

def _ensure_retriever_modifiable(self):
Expand Down Expand Up @@ -266,3 +297,7 @@ def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] =
return MockEmbedding(embed_dim=1)

return embed_model or get_rag_embedding()

@staticmethod
def _default_transformations():
    """Return the transformation list used when the caller supplies none."""
    default_pipeline = [SentenceSplitter()]
    return default_pipeline
17 changes: 11 additions & 6 deletions metagpt/rag/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""

def get_instance(self, key: Any, **kwargs) -> Any:
"""Key is config, such as a pydantic model.
"""Get instance by the type of key.
Call func by the type of key, and the key will be passed to func.
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
Raise Exception if key not found.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)

self._raise_for_key(key)

def _raise_for_key(self, key: Any):
raise ValueError(f"Unknown config: `{type(key)}`, {key}")

@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
Return None if not found.
"""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
Expand All @@ -57,6 +64,4 @@ def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any
if key in kwargs:
return kwargs[key]

raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
return None
67 changes: 47 additions & 20 deletions metagpt/rag/factories/retriever.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""RAG Retriever Factory."""

import copy

import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
Expand All @@ -24,7 +25,6 @@
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)


Expand Down Expand Up @@ -54,48 +54,75 @@ def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) ->
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]

def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)

return index.as_retriever()

def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._extract_index(config, **kwargs) or self._build_faiss_index(config, **kwargs)

return FAISSRetriever(**config.model_dump())

def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = self._extract_nodes(config, **kwargs)

return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())

def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)

vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_chroma_index(config, **kwargs)

return ChromaRetriever(**config.model_dump())

def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_es_index(config, **kwargs)

return ElasticsearchRetriever(**config.model_dump())

def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
    """Look up ``index`` via ``_val_from_config_or_kwargs`` using ``config`` and ``kwargs``."""
    found_index = self._val_from_config_or_kwargs("index", config, **kwargs)
    return found_index

def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
    """Look up ``nodes`` via ``_val_from_config_or_kwargs`` using ``config`` and ``kwargs``."""
    found_nodes = self._val_from_config_or_kwargs("nodes", config, **kwargs)
    return found_nodes

def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
    """Look up ``embed_model`` via ``_val_from_config_or_kwargs`` using ``config`` and ``kwargs``."""
    found_embed_model = self._val_from_config_or_kwargs("embed_model", config, **kwargs)
    return found_embed_model

def _build_default_index(self, **kwargs) -> VectorStoreIndex:
    """Build a ``VectorStoreIndex`` from the ``nodes`` and ``embed_model`` found in ``kwargs``."""
    source_nodes = self._extract_nodes(**kwargs)
    embedding = self._extract_embed_model(**kwargs)

    return VectorStoreIndex(nodes=source_nodes, embed_model=embedding)

def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index backed by a FAISS ``IndexFlatL2`` sized by ``config.dimensions``."""
    flat_l2 = faiss.IndexFlatL2(config.dimensions)
    faiss_store = FaissVectorStore(faiss_index=flat_l2)

    return self._build_index_from_vector_store(config, faiss_store, **kwargs)

def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index backed by a Chroma collection.

    A persistent client is opened at ``config.persist_path`` and the collection
    named ``config.collection_name`` is fetched or created with ``config.metadata``.
    """
    client = chromadb.PersistentClient(path=str(config.persist_path))
    collection = client.get_or_create_collection(config.collection_name, metadata=config.metadata)
    chroma_store = ChromaVectorStore(chroma_collection=collection)

    return self._build_index_from_vector_store(config, chroma_store, **kwargs)

def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index backed by an ``ElasticsearchStore`` made from ``config.store_config``."""
    store_settings = config.store_config.model_dump()
    es_store = ElasticsearchStore(**store_settings)

    return self._build_index_from_vector_store(config, es_store, **kwargs)

def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
index = VectorStoreIndex(
nodes=self._extract_nodes(config, **kwargs),
storage_context=storage_context,
embed_model=old_index._embed_model,
embed_model=self._extract_embed_model(config, **kwargs),
)
return new_index

return index


get_retriever = RetrieverFactory().get_retriever
24 changes: 7 additions & 17 deletions tests/metagpt/rag/engines/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ def mock_embedding(self):
def mock_simple_directory_reader(self, mocker):
    """Patch ``SimpleDirectoryReader`` in the simple engine module and return the patch result."""
    patched_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
    return patched_reader

@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")

@pytest.fixture
def mock_get_retriever(self, mocker):
    """Patch ``get_retriever`` in the simple engine module and return the patch result."""
    patched_factory = mocker.patch("metagpt.rag.engines.simple.get_retriever")
    return patched_factory
Expand All @@ -45,7 +41,6 @@ def test_from_docs(
self,
mocker,
mock_simple_directory_reader,
mock_vector_store_index,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
Expand Down Expand Up @@ -81,11 +76,8 @@ def test_from_docs(

# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
configs=retriever_configs, index=mock_vector_store_index.return_value
)
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)

Expand Down Expand Up @@ -119,7 +111,7 @@ def model_dump_json(self):

# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is not None
assert engine._transformations is not None

def test_from_objs_with_bm25_config(self):
# Setup
Expand All @@ -137,6 +129,7 @@ def test_from_objs_with_bm25_config(self):
def test_from_index(self, mocker, mock_llm, mock_embedding):
# Mock
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.as_retriever.return_value = "retriever"
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
mock_get_index.return_value = mock_index

Expand All @@ -149,7 +142,7 @@ def test_from_index(self, mocker, mock_llm, mock_embedding):

# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is mock_index
assert engine._retriever == "retriever"

@pytest.mark.asyncio
async def test_asearch(self, mocker):
Expand Down Expand Up @@ -200,14 +193,11 @@ def test_add_docs(self, mocker):

mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)

mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index._transformations = mocker.MagicMock()

mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
mock_run_transformations.return_value = ["node1", "node2"]

# Setup
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
engine = SimpleEngine(retriever=mock_retriever)
input_files = ["test_file1", "test_file2"]

# Exec
Expand All @@ -230,7 +220,7 @@ def model_dump_json(self):
return ""

objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
engine = SimpleEngine(retriever=mock_retriever)

# Exec
engine.add_objs(objs=objs)
Expand Down
5 changes: 2 additions & 3 deletions tests/metagpt/rag/factories/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,5 @@ def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
def test_val_from_config_or_kwargs_key_error(self):
# Test that None is returned when the key is found in neither the config object nor kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert val is None
Loading

0 comments on commit 263d799

Please sign in to comment.