From 263d79980e10993563db50582a3ecdde26020a61 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 23 Apr 2024 21:48:50 +0800 Subject: [PATCH] fix redundant embedding --- metagpt/rag/engines/simple.py | 63 +++++++++++++---- metagpt/rag/factories/base.py | 17 +++-- metagpt/rag/factories/retriever.py | 67 +++++++++++++------ tests/metagpt/rag/engines/test_simple.py | 24 ++----- tests/metagpt/rag/factories/test_base.py | 5 +- tests/metagpt/rag/factories/test_retriever.py | 34 ++++++---- 6 files changed, 137 insertions(+), 73 deletions(-) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 34f925249..c237dcf69 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -4,7 +4,7 @@ import os from typing import Any, Optional, Union -from llama_index.core import SimpleDirectoryReader, VectorStoreIndex +from llama_index.core import SimpleDirectoryReader from llama_index.core.callbacks.base import CallbackManager from llama_index.core.embeddings import BaseEmbedding from llama_index.core.embeddings.mock_embed_model import MockEmbedding @@ -63,7 +63,7 @@ def __init__( response_synthesizer: Optional[BaseSynthesizer] = None, node_postprocessors: Optional[list[BaseNodePostprocessor]] = None, callback_manager: Optional[CallbackManager] = None, - index: Optional[BaseIndex] = None, + transformations: Optional[list[TransformComponent]] = None, ) -> None: super().__init__( retriever=retriever, @@ -71,7 +71,7 @@ def __init__( node_postprocessors=node_postprocessors, callback_manager=callback_manager, ) - self.index = index + self._transformations = transformations or self._default_transformations() @classmethod def from_docs( @@ -103,12 +103,17 @@ def from_docs( documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() cls._fix_document_metadata(documents) - index = VectorStoreIndex.from_documents( - documents=documents, - transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations = transformations or cls._default_transformations() + nodes = run_transformations(documents, transformations=transformations) + + return cls._from_nodes( + nodes=nodes, + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_objs( @@ -137,12 +142,15 @@ def from_objs( raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] - index = VectorStoreIndex( + + return cls._from_nodes( nodes=nodes, - transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_index( @@ -183,7 +191,7 @@ def add_docs(self, input_files: list[str]): documents = SimpleDirectoryReader(input_files=input_files).load_data() self._fix_document_metadata(documents) - nodes = run_transformations(documents, transformations=self.index._transformations) + nodes = run_transformations(documents, transformations=self._transformations) self._save_nodes(nodes) def add_objs(self, objs: list[RAGObject]): @@ -199,6 +207,29 @@ def persist(self, persist_dir: Union[str, os.PathLike], **kwargs): self._persist(str(persist_dir), **kwargs) + @classmethod + def _from_nodes( + cls, + nodes: list[BaseNode], + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ) -> "SimpleEngine": + embed_model = cls._resolve_embed_model(embed_model, retriever_configs) + llm = llm or get_rag_llm() + + retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model) + rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] + + return cls( + retriever=retriever, + node_postprocessors=rankers, + response_synthesizer=get_response_synthesizer(llm=llm), + transformations=transformations, + ) + @classmethod def _from_index( cls, @@ -208,6 +239,7 @@ def _from_index( ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": llm = llm or get_rag_llm() + retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] @@ -215,7 +247,6 @@ def _from_index( retriever=retriever, node_postprocessors=rankers, response_synthesizer=get_response_synthesizer(llm=llm), - index=index, ) def _ensure_retriever_modifiable(self): @@ -266,3 +297,7 @@ def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = return MockEmbedding(embed_dim=1) return embed_model or get_rag_embedding() + + @staticmethod + def _default_transformations(): + return [SentenceSplitter()] diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index fcfec03ec..e58643efe 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -36,19 +36,26 @@ class ConfigBasedFactory(GenericFactory): """Designed to get objects based on object type.""" def get_instance(self, key: Any, **kwargs) -> Any: - """Key is config, such as a pydantic model. + """Get instance by the type of key. - Call func by the type of key, and the key will be passed to func. + Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func. + Raise Exception if key not found. """ creator = self._creators.get(type(key)) if creator: return creator(key, **kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Unknown config: `{type(key)}`, {key}") @staticmethod def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any: - """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.""" + """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs. + + Return None if not found. + """ if config is not None and hasattr(config, key): val = getattr(config, key) if val is not None: @@ -57,6 +64,4 @@ def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any if key in kwargs: return kwargs[key] - raise KeyError( - f"The key '{key}' is required but not provided in either configuration object or keyword arguments." - ) + return None diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 68f2c2313..dd6261d52 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -1,10 +1,11 @@ """RAG Retriever Factory.""" -import copy import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore @@ -24,7 +25,6 @@ ElasticsearchKeywordRetrieverConfig, ElasticsearchRetrieverConfig, FAISSRetrieverConfig, - IndexRetrieverConfig, ) @@ -54,48 +54,75 @@ def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0] def _create_default(self, **kwargs) -> RAGRetriever: - return self._extract_index(**kwargs).as_retriever() + index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs) + + return index.as_retriever() def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._extract_index(config, **kwargs) or self._build_faiss_index(config, **kwargs) return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: - config.index = copy.deepcopy(self._extract_index(config, **kwargs)) + nodes = self._extract_nodes(config, **kwargs) - return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump()) + return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: - db = chromadb.PersistentClient(path=str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) - - vector_store = ChromaVectorStore(chroma_collection=chroma_collection) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_chroma_index(config, **kwargs) return ChromaRetriever(**config.model_dump()) def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever: - vector_store = ElasticsearchStore(**config.store_config.model_dump()) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_es_index(config, **kwargs) return ElasticsearchRetriever(**config.model_dump()) def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) + def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]: + return self._val_from_config_or_kwargs("nodes", config, **kwargs) + + def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding: + return self._val_from_config_or_kwargs("embed_model", config, **kwargs) + + def _build_default_index(self, **kwargs) -> VectorStoreIndex: + index = VectorStoreIndex( + nodes=self._extract_nodes(**kwargs), + embed_model=self._extract_embed_model(**kwargs), + ) + + return index + + def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + + def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex: + db = chromadb.PersistentClient(path=str(config.persist_path)) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + + def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + def _build_index_from_vector_store( - self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs + self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs ) -> VectorStoreIndex: storage_context = StorageContext.from_defaults(vector_store=vector_store) - old_index = self._extract_index(config, **kwargs) - new_index = VectorStoreIndex( - nodes=list(old_index.docstore.docs.values()), + index = VectorStoreIndex( + nodes=self._extract_nodes(config, **kwargs), storage_context=storage_context, - embed_model=old_index._embed_model, + embed_model=self._extract_embed_model(config, **kwargs), ) - return new_index + + return index get_retriever = RetrieverFactory().get_retriever diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 9262ccb07..8c7a15be2 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -25,10 +25,6 @@ def mock_embedding(self): def mock_simple_directory_reader(self, mocker): return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") - @pytest.fixture - def mock_vector_store_index(self, mocker): - return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") - @pytest.fixture def mock_get_retriever(self, mocker): return mocker.patch("metagpt.rag.engines.simple.get_retriever") @@ -45,7 +41,6 @@ def test_from_docs( self, mocker, mock_simple_directory_reader, - mock_vector_store_index, mock_get_retriever, mock_get_rankers, mock_get_response_synthesizer, @@ -81,11 +76,8 @@ def test_from_docs( # Assert mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) - mock_vector_store_index.assert_called_once() - mock_get_retriever.assert_called_once_with( - configs=retriever_configs, index=mock_vector_store_index.return_value - ) - mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm) + mock_get_retriever.assert_called_once() + mock_get_rankers.assert_called_once() mock_get_response_synthesizer.assert_called_once_with(llm=llm) assert isinstance(engine, SimpleEngine) @@ -119,7 +111,7 @@ def model_dump_json(self): # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is not None + assert engine._transformations is not None def test_from_objs_with_bm25_config(self): # Setup @@ -137,6 +129,7 @@ def test_from_objs_with_bm25_config(self): def test_from_index(self, mocker, mock_llm, mock_embedding): # Mock mock_index = mocker.MagicMock(spec=VectorStoreIndex) + mock_index.as_retriever.return_value = "retriever" mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index") mock_get_index.return_value = mock_index @@ -149,7 +142,7 @@ def test_from_index(self, mocker, mock_llm, mock_embedding): # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is mock_index + assert engine._retriever == "retriever" @pytest.mark.asyncio async def test_asearch(self, mocker): @@ -200,14 +193,11 @@ def test_add_docs(self, mocker): mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever) - mock_index = mocker.MagicMock(spec=VectorStoreIndex) - mock_index._transformations = mocker.MagicMock() - mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations") mock_run_transformations.return_value = ["node1", "node2"] # Setup - engine = SimpleEngine(retriever=mock_retriever, index=mock_index) + engine = SimpleEngine(retriever=mock_retriever) input_files = ["test_file1", "test_file2"] # Exec @@ -230,7 +220,7 @@ def model_dump_json(self): return "" objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)] - engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock()) + engine = SimpleEngine(retriever=mock_retriever) # Exec engine.add_objs(objs=objs) diff --git a/tests/metagpt/rag/factories/test_base.py b/tests/metagpt/rag/factories/test_base.py index 1d41e1872..0b0a44976 100644 --- a/tests/metagpt/rag/factories/test_base.py +++ b/tests/metagpt/rag/factories/test_base.py @@ -97,6 +97,5 @@ def test_val_from_config_or_kwargs_fallback_to_kwargs(self): def test_val_from_config_or_kwargs_key_error(self): # Test KeyError when the key is not found in both config object and kwargs config = DummyConfig(name=None) - with pytest.raises(KeyError) as exc_info: - ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) - assert "The key 'missing_key' is required but not provided" in str(exc_info.value) + val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) + assert val is None diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index ef1cef7e0..a70639f55 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -1,6 +1,8 @@ import faiss import pytest from llama_index.core import VectorStoreIndex +from llama_index.core.embeddings import MockEmbedding +from llama_index.core.schema import TextNode from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore @@ -43,6 +45,14 @@ def mock_chroma_vector_store(self, mocker): def mock_es_vector_store(self, mocker): return mocker.MagicMock(spec=ElasticsearchStore) + @pytest.fixture + def mock_nodes(self, mocker): + return [TextNode(text="msg")] + + @pytest.fixture + def mock_embedding(self): + return MockEmbedding(embed_dim=1) + def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index): mock_config = FAISSRetrieverConfig(dimensions=128) mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index) @@ -52,42 +62,40 @@ def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_ve assert isinstance(retriever, FAISSRetriever) - def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index): + def test_get_retriever_with_bm25_config(self, mocker, mock_nodes): mock_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes) assert isinstance(retriever, DynamicBM25Retriever) - def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index): - mock_faiss_config = FAISSRetrieverConfig(dimensions=128) + def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding): + mock_faiss_config = FAISSRetrieverConfig(dimensions=1) mock_bm25_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config]) + retriever = self.retriever_factory.get_retriever( + configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding + ) assert isinstance(retriever, SimpleHybridRetriever) - def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store): + def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding): mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection") mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient") mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock() mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) assert isinstance(retriever, ChromaRetriever) - def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store): + def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding): mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig()) mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) assert isinstance(retriever, ElasticsearchRetriever)