Skip to content

Commit

Permalink
fix redundant embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
seehi committed Apr 24, 2024
1 parent e976ece commit b8b1a66
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
4 changes: 3 additions & 1 deletion metagpt/rag/engines/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def _from_nodes(
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
llm = llm or get_rag_llm()

retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
retriever = get_retriever(
configs=retriever_configs, nodes=nodes, embed_model=embed_model
) # Default VectorStoreIndex(nodes, embed_model).as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []

return cls(
Expand Down
26 changes: 24 additions & 2 deletions metagpt/rag/factories/retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""RAG Retriever Factory."""


from functools import wraps

import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
Expand Down Expand Up @@ -28,6 +30,22 @@
)


def get_or_build_index(build_index_func):
    """Decorate an index-building method so that an existing index is reused.

    The decorated method first calls ``self._extract_index(config, **kwargs)``.
    If that returns anything other than ``None`` it is returned as-is;
    otherwise ``build_index_func`` is invoked with the same arguments.
    """

    @wraps(build_index_func)
    def get_or_build(self, config, **kwargs):
        existing = self._extract_index(config, **kwargs)
        if existing is None:
            # Nothing to reuse: delegate to the real builder.
            return build_index_func(self, config, **kwargs)
        return existing

    return get_or_build


class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""

Expand Down Expand Up @@ -59,12 +77,13 @@ def _create_default(self, **kwargs) -> RAGRetriever:
return index.as_retriever()

def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
    """Create a FAISS retriever whose index is stored on ``config.index``.

    Args:
        config: FAISS retriever configuration; its ``index`` attribute is set here.
        **kwargs: Forwarded unchanged to ``_build_faiss_index``.

    Returns:
        A ``FAISSRetriever`` constructed from ``config.model_dump()``.
    """
    # `_build_faiss_index` is wrapped by `get_or_build_index`, which already returns
    # an existing index when `_extract_index` finds one. A separate
    # `_extract_index(...) or _build_faiss_index(...)` assignment here would be
    # overwritten immediately and only repeat that lookup.
    config.index = self._build_faiss_index(config, **kwargs)

    return FAISSRetriever(**config.model_dump())

def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
    """Create a BM25 retriever, taking nodes from an existing index when there is one.

    Args:
        config: BM25 retriever configuration.
        **kwargs: Forwarded to ``_extract_index`` and, as a fallback, ``_extract_nodes``.

    Returns:
        A ``DynamicBM25Retriever`` built from the selected nodes and ``config.model_dump()``.
    """
    index = self._extract_index(config, **kwargs)
    # Read nodes from the index's docstore when an index was found; `_extract_nodes`
    # is only called as the fallback, not unconditionally up front.
    nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)

    return DynamicBM25Retriever(nodes=nodes, **config.model_dump())

Expand Down Expand Up @@ -95,18 +114,21 @@ def _build_default_index(self, **kwargs) -> VectorStoreIndex:

return index

@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index backed by a flat L2 FAISS index of ``config.dimensions`` dimensions."""
    flat_l2 = faiss.IndexFlatL2(config.dimensions)
    store = FaissVectorStore(faiss_index=flat_l2)

    return self._build_index_from_vector_store(config, store, **kwargs)

@get_or_build_index
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index backed by a Chroma collection persisted at ``config.persist_path``."""
    client = chromadb.PersistentClient(path=str(config.persist_path))
    collection = client.get_or_create_collection(
        config.collection_name,
        metadata=config.metadata,
    )
    store = ChromaVectorStore(chroma_collection=collection)

    return self._build_index_from_vector_store(config, store, **kwargs)

@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

Expand Down
16 changes: 16 additions & 0 deletions tests/metagpt/rag/factories/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,19 @@ def test_extract_index_from_kwargs(self, mock_vector_store_index):
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)

assert extracted_index == mock_vector_store_index

def test_get_or_build_when_get(self, mocker):
    """An index returned by `_extract_index` is handed back without building."""
    expected = "existing_index"
    factory = self.retriever_factory
    mocker.patch.object(factory, "_extract_index", return_value=expected)

    actual = factory._build_es_index(None)

    assert actual == expected

def test_get_or_build_when_build(self, mocker):
    """When `_extract_index` finds nothing, the decorated builder is called.

    Patching `_build_es_index` itself would replace the decorated method with a
    mock and never run `get_or_build_index`, so a stub builder is wrapped with
    the real decorator instead.
    """
    from metagpt.rag.factories.retriever import get_or_build_index

    want = "call_build_es_index"
    mocker.patch.object(self.retriever_factory, "_extract_index", return_value=None)
    build_calls = []

    @get_or_build_index
    def build_index(factory, config, **kwargs):
        # Record the arguments so the test can check the builder really ran.
        build_calls.append((factory, config, kwargs))
        return want

    got = build_index(self.retriever_factory, None)

    assert got == want
    assert build_calls == [(self.retriever_factory, None, {})]

0 comments on commit b8b1a66

Please sign in to comment.