From 490203d20f29c6d8ba010c5c289904c98e33457d Mon Sep 17 00:00:00 2001 From: Jacksonxhx Date: Wed, 14 Aug 2024 16:27:30 +0800 Subject: [PATCH 1/5] integrated milvus --- metagpt/document_store/milvus_store.py | 124 ++++++++++++++++++ requirements.txt | 1 + .../document_store/test_milvus_store.py | 48 +++++++ 3 files changed, 173 insertions(+) create mode 100644 metagpt/document_store/milvus_store.py create mode 100644 tests/metagpt/document_store/test_milvus_store.py diff --git a/metagpt/document_store/milvus_store.py b/metagpt/document_store/milvus_store.py new file mode 100644 index 000000000..9d5de93cd --- /dev/null +++ b/metagpt/document_store/milvus_store.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from typing import List, Dict, Any, Optional +from pymilvus import MilvusClient, DataType + +from metagpt.document_store.base_store import BaseStore + +@dataclass +class MilvusConnection: + """ + Args: + uri: milvus url + token: milvus token + """ + + uri: str = None + token: str = None + + +class MilvusStore(BaseStore): + def __init__(self, connect: MilvusConnection): + if not connect.uri: + raise Exception("please check MilvusConnection, uri must be set.") + self.client = MilvusClient( + uri=connect.uri, + token=connect.token + ) + + def create_collection( + self, + collection_name: str, + dim: int, + enable_dynamic_schema: bool = True + ): + if self.client.has_collection(collection_name=collection_name): + self.client.drop_collection(collection_name=collection_name) + + schema = self.client.create_schema( + auto_id=False, + enable_dynamic_field=False, + ) + schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36) + schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim) + + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="vector", + index_type="AUTOINDEX", + metric_type="COSINE" + ) + + self.client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_params, + enable_dynamic_schema=enable_dynamic_schema + ) + + @staticmethod + def build_filter(key, value) -> str: + if isinstance(value, str): + filter_expression = f'{key} == "{value}"' + else: + if isinstance(value, list): + filter_expression = f'{key} in {value}' + else: + filter_expression = f'{key} == {value}' + + return filter_expression + + def search( + self, + collection_name: str, + query: List[float], + filter: Dict[str, str | int | list[int]] = None, + limit: int = 10, + output_fields: Optional[List[str]] = None, + ) -> List[dict]: + filter_expression = '' + + for key, value in filter.items(): + filter_expression += f'{self.build_filter(key, value)} and ' + print(filter_expression) + + res = self.client.search( + collection_name=collection_name, + data=[query], + filter=filter_expression, + limit=limit, + output_fields=output_fields, + )[0] + + return res + + def add( + self, + collection_name: str, + _ids: List[str], + vector: List[List[float]], + metadata: List[Dict[str, Any]] + ): + data = dict() + + for i, id in enumerate(_ids): + data['id'] = id + data['vector'] = vector[i] + data['metadata'] = metadata[i] + + self.client.upsert( + collection_name=collection_name, + data=data + ) + + def delete( + self, + collection_name: str, + _ids: List[str] + ): + self.client.delete( + collection_name=collection_name, + ids=_ids + ) + + def write(self, *args, **kwargs): + pass diff --git a/requirements.txt b/requirements.txt index 8bf0ee399..92f5654da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -79,3 +79,4 @@ gymnasium==0.29.1 boto3~=1.34.69 spark_ai_python~=0.3.30 agentops +pymilvus==2.4.5 diff --git a/tests/metagpt/document_store/test_milvus_store.py b/tests/metagpt/document_store/test_milvus_store.py new file mode 100644 index 000000000..7cfd31381 --- /dev/null +++ b/tests/metagpt/document_store/test_milvus_store.py @@ -0,0 +1,48 @@ +import random +from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore + +seed_value = 42 +random.seed(seed_value) + +vectors = [[random.random() for _ in range(8)] for _ in range(10)] +ids = [f"doc_{i}" for i in range(10)] +metadata = [{"color": "red", "rand_number": i % 10} for i in range(10)] + + +def assert_almost_equal(actual, expected): + delta = 1e-10 + if isinstance(expected, list): + assert len(actual) == len(expected) + for ac, exp in zip(actual, expected): + assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}" + else: + assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}" + + +def test_milvus_store(): + milvus_connection = MilvusConnection(uri="./milvus_local.db") + milvus_store = MilvusStore(milvus_connection) + + collection_name = "TestCollection" + milvus_store.create_collection(collection_name, dim=8) + + milvus_store.add(collection_name, ids, vectors, metadata) + + search_results = milvus_store.search(collection_name, query=[1.0] * 8) + assert len(search_results) > 0 + first_result = search_results[0] + assert first_result["id"] == "doc_0" + + search_results_with_filter = milvus_store.search( + collection_name, + query=[1.0] * 8, + filter={"rand_number": 1} + ) + assert len(search_results_with_filter) > 0 + assert search_results_with_filter[0]["id"] == "doc_1" + + milvus_store.delete(collection_name, _ids=["doc_0"]) + deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1) + assert deleted_results[0]["id"] != "doc_0" + + milvus_store.client.drop_collection(collection_name) From 986fb784aaeedcb2856c5e7bff64a3d11c982dbc Mon Sep 17 00:00:00 2001 From: Jacksonxhx Date: Wed, 14 Aug 2024 18:16:18 +0800 Subject: [PATCH 2/5] integrate milvus --- metagpt/rag/factories/index.py | 8 +++++++ metagpt/rag/factories/retriever.py | 16 +++++++++++++- metagpt/rag/retrievers/milvus_retriever.py | 17 +++++++++++++++ metagpt/rag/schema.py | 21 +++++++++++++++++++ tests/metagpt/rag/factories/test_index.py | 17 ++++++++++++++- tests/metagpt/rag/factories/test_retriever.py | 15 +++++++++++++ 6 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 metagpt/rag/retrievers/milvus_retriever.py diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index f897af3ad..6da4900a0 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -8,6 +8,7 @@ from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore +from llama_index.vector_stores.milvus import MilvusVectorStore from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.schema import ( @@ -17,6 +18,7 @@ ElasticsearchIndexConfig, ElasticsearchKeywordIndexConfig, FAISSIndexConfig, + MilvusIndexConfig, ) @@ -28,6 +30,7 @@ def __init__(self): BM25IndexConfig: self._create_bm25, ElasticsearchIndexConfig: self._create_es, ElasticsearchKeywordIndexConfig: self._create_es, + MilvusIndexConfig: self._create_milvus } super().__init__(creators) @@ -46,6 +49,11 @@ def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: return self._index_from_storage(storage_context=storage_context, config=config, **kwargs) + def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex: + vector_store = MilvusVectorStore(collection_name=config.collection_name, uri=config.uri, token=config.token) + + return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) + def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: db = chromadb.PersistentClient(str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 1460e131b..c3d3a4f80 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -12,6 +12,7 @@ from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore +from llama_index.vector_stores.milvus import MilvusVectorStore from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.retrievers.base import RAGRetriever @@ -20,13 +21,14 @@ from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever from metagpt.rag.schema import ( BaseRetrieverConfig, BM25RetrieverConfig, ChromaRetrieverConfig, ElasticsearchKeywordRetrieverConfig, ElasticsearchRetrieverConfig, - FAISSRetrieverConfig, + FAISSRetrieverConfig, MilvusRetrieverConfig, ) @@ -56,6 +58,7 @@ def __init__(self): ChromaRetrieverConfig: self._create_chroma_retriever, ElasticsearchRetrieverConfig: self._create_es_retriever, ElasticsearchKeywordRetrieverConfig: self._create_es_retriever, + MilvusRetrieverConfig: self._create_milvus_retriever, } super().__init__(creators) @@ -76,6 +79,11 @@ def _create_default(self, **kwargs) -> RAGRetriever: return index.as_retriever() + def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever: + config.index = self._build_milvus_index(config, **kwargs) + + return MilvusRetriever(**config.model_dump()) + def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: config.index = self._build_faiss_index(config, **kwargs) @@ -128,6 +136,12 @@ def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> Vector return self._build_index_from_vector_store(config, vector_store, **kwargs) + @get_or_build_index + def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + @get_or_build_index def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex: vector_store = ElasticsearchStore(**config.store_config.model_dump()) diff --git a/metagpt/rag/retrievers/milvus_retriever.py b/metagpt/rag/retrievers/milvus_retriever.py new file mode 100644 index 000000000..ff2562bd8 --- /dev/null +++ b/metagpt/rag/retrievers/milvus_retriever.py @@ -0,0 +1,17 @@ +"""Milvus retriever.""" + +from llama_index.core.retrievers import VectorIndexRetriever +from llama_index.core.schema import BaseNode + + +class MilvusRetriever(VectorIndexRetriever): + """Milvus retriever.""" + + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: + """Support add nodes.""" + self._index.insert_nodes(nodes, **kwargs) + + def persist(self, persist_dir: str, **kwargs) -> None: + """Support persist. + + Milvus automatically saves, so there is no need to implement.""" \ No newline at end of file diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index a8a10f90e..89e189235 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -62,6 +62,17 @@ class BM25RetrieverConfig(IndexRetrieverConfig): _no_embedding: bool = PrivateAttr(default=True) +class MilvusRetrieverConfig(IndexRetrieverConfig): + """Config for Milvus-based retrievers.""" + + uri: str = Field(default="./milvus_local.db", description="The directory to save data.") + collection_name: str = Field(default="metagpt", description="The name of the collection.") + token: str = Field(default=None, description="The token for Milvus") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) + + class ChromaRetrieverConfig(IndexRetrieverConfig): """Config for Chroma-based retrievers.""" @@ -169,6 +180,16 @@ class ChromaIndexConfig(VectorIndexConfig): default=None, description="Optional metadata to associate with the collection" ) +class MilvusIndexConfig(VectorIndexConfig): + """Config for milvus-based index.""" + + collection_name: str = Field(default="metagpt", description="The name of the collection.") + uri: str = Field(default="./milvus_local.db", description="The uri of the index.") + token: Optional[str] = Field(default=None, description="The token of the index.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) + class BM25IndexConfig(BaseIndexConfig): """Config for bm25-based index.""" diff --git a/tests/metagpt/rag/factories/test_index.py b/tests/metagpt/rag/factories/test_index.py index 9dc5bfb6b..5d8711f9f 100644 --- a/tests/metagpt/rag/factories/test_index.py +++ b/tests/metagpt/rag/factories/test_index.py @@ -7,7 +7,7 @@ ChromaIndexConfig, ElasticsearchIndexConfig, ElasticsearchStoreConfig, - FAISSIndexConfig, + FAISSIndexConfig, MilvusIndexConfig, ) @@ -20,6 +20,10 @@ def setup(self): def faiss_config(self): return FAISSIndexConfig(persist_path="") + @pytest.fixture + def milvus_config(self): + return MilvusIndexConfig(uri="", collection_name="") + @pytest.fixture def chroma_config(self): return ChromaIndexConfig(persist_path="", collection_name="") @@ -65,6 +69,17 @@ def test_create_bm25_index( ): self.index_factory.get_index(bm25_config, embed_model=mock_embedding) + + def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding): + # Mock + mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore") + + # Exec + self.index_factory.get_index(milvus_config, embed_model=mock_embedding) + + # Assert + mock_milvus_store.assert_called_once() + def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding): # Mock mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient") diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index cd55a32db..149e4b172 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -5,6 +5,7 @@ from llama_index.core.schema import TextNode from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore +from llama_index.vector_stores.milvus import MilvusVectorStore from metagpt.rag.factories.retriever import RetrieverFactory from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever @@ -12,12 +13,14 @@ from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever from metagpt.rag.schema import ( BM25RetrieverConfig, ChromaRetrieverConfig, ElasticsearchRetrieverConfig, ElasticsearchStoreConfig, FAISSRetrieverConfig, + MilvusRetrieverConfig, ) @@ -41,6 +44,10 @@ def mock_vector_store_index(self, mocker): def mock_chroma_vector_store(self, mocker): return mocker.MagicMock(spec=ChromaVectorStore) + @pytest.fixture + def mock_milvus_vector_store(self, mocker): + return mocker.MagicMock(spec=MilvusVectorStore) + @pytest.fixture def mock_es_vector_store(self, mocker): return mocker.MagicMock(spec=ElasticsearchStore) @@ -91,6 +98,14 @@ def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store assert isinstance(retriever, ChromaRetriever) + def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding): + mock_config = MilvusRetrieverConfig(uri="/path/to/milvus", collection_name="test_collection") + mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store) + + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) + + assert isinstance(retriever, MilvusRetriever) + def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding): mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig()) mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store) From 961146ef3eb2dfd7c7d6c3b7a85e05015f769918 Mon Sep 17 00:00:00 2001 From: Jacksonxhx Date: Tue, 20 Aug 2024 17:19:35 +0800 Subject: [PATCH 3/5] update milvus integration --- metagpt/rag/factories/retriever.py | 5 +++-- metagpt/rag/schema.py | 21 ++++++++++++++++++- tests/metagpt/rag/factories/test_index.py | 1 - tests/metagpt/rag/factories/test_retriever.py | 2 +- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index c3d3a4f80..3342b8905 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -28,7 +28,8 @@ ChromaRetrieverConfig, ElasticsearchKeywordRetrieverConfig, ElasticsearchRetrieverConfig, - FAISSRetrieverConfig, MilvusRetrieverConfig, + FAISSRetrieverConfig, + MilvusRetrieverConfig, ) @@ -138,7 +139,7 @@ def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> Vector @get_or_build_index def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex: - vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token) + vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions) return self._build_index_from_vector_store(config, vector_store, **kwargs) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 89e189235..e4d97068d 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -8,7 +8,7 @@ from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator from metagpt.config2 import config from metagpt.configs.embedding_config import EmbeddingType @@ -71,6 +71,25 @@ class MilvusRetrieverConfig(IndexRetrieverConfig): metadata: Optional[CollectionMetadata] = Field( default=None, description="Optional metadata to associate with the collection" ) + dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.") + + _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { + EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, + } + + @model_validator(mode="after") + def check_dimensions(self): + if self.dimensions == 0: + self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( + config.embedding.api_type, 1536 + ) + if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions: + logger.warning( + f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536" + ) + + return self class ChromaRetrieverConfig(IndexRetrieverConfig): diff --git a/tests/metagpt/rag/factories/test_index.py b/tests/metagpt/rag/factories/test_index.py index 5d8711f9f..9861e1242 100644 --- a/tests/metagpt/rag/factories/test_index.py +++ b/tests/metagpt/rag/factories/test_index.py @@ -69,7 +69,6 @@ def test_create_bm25_index( ): self.index_factory.get_index(bm25_config, embed_model=mock_embedding) - def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding): # Mock mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore") diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index 149e4b172..b808de26e 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -99,7 +99,7 @@ def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store assert isinstance(retriever, ChromaRetriever) def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding): - mock_config = MilvusRetrieverConfig(uri="/path/to/milvus", collection_name="test_collection") + mock_config = MilvusRetrieverConfig(uri="/path/to/milvus.db", collection_name="test_collection") mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store) retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) From e5f037f86de9cee983f2b4b7e3d99db1253cfbdb Mon Sep 17 00:00:00 2001 From: ChengZi Date: Tue, 24 Sep 2024 20:37:15 +0800 Subject: [PATCH 4/5] update dependency Signed-off-by: ChengZi --- requirements.txt | 5 ++--- setup.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 92f5654da..b4f3f563d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ beautifulsoup4==4.12.3 pandas==2.1.1 pydantic>=2.5.3 #pygame==2.1.3 -#pymilvus==2.2.8 +# pymilvus==2.4.6 # pytest==7.2.2 # test extras require python_docx==0.8.11 PyYAML==6.0.1 @@ -78,5 +78,4 @@ volcengine-python-sdk[ark]~=1.0.94 gymnasium==0.29.1 boto3~=1.34.69 spark_ai_python~=0.3.30 -agentops -pymilvus==2.4.5 +agentops \ No newline at end of file diff --git a/setup.py b/setup.py index 8ba4c8a72..f1dbc113d 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ def run(self): "llama-index-postprocessor-cohere-rerank==0.1.4", "llama-index-postprocessor-colbert-rerank==0.1.1", "llama-index-postprocessor-flag-embedding-reranker==0.1.2", + # "llama-index-vector-stores-milvus==0.1.23", "docx2txt==0.8", ], } From 4d92fdcec97f3f063d46e14e8e9a1329695fb3ad Mon Sep 17 00:00:00 2001 From: ChengZi Date: Wed, 25 Sep 2024 11:47:28 +0800 Subject: [PATCH 5/5] lazy dependency for milvus Signed-off-by: ChengZi --- metagpt/document_store/milvus_store.py | 71 ++++++------------- .../document_store/test_milvus_store.py | 10 +-- 2 files changed, 28 insertions(+), 53 deletions(-) diff --git a/metagpt/document_store/milvus_store.py b/metagpt/document_store/milvus_store.py index 9d5de93cd..e4d6d985e 100644 --- a/metagpt/document_store/milvus_store.py +++ b/metagpt/document_store/milvus_store.py @@ -1,9 +1,9 @@ from dataclasses import dataclass -from typing import List, Dict, Any, Optional -from pymilvus import MilvusClient, DataType +from typing import Any, Dict, List, Optional from metagpt.document_store.base_store import BaseStore + @dataclass class MilvusConnection: """ @@ -18,19 +18,17 @@ class MilvusConnection: class MilvusStore(BaseStore): def __init__(self, connect: MilvusConnection): + try: + from pymilvus import MilvusClient + except ImportError: + raise Exception("Please install pymilvus first.") if not connect.uri: raise Exception("please check MilvusConnection, uri must be set.") - self.client = MilvusClient( - uri=connect.uri, - token=connect.token - ) + self.client = MilvusClient(uri=connect.uri, token=connect.token) + + def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True): + from pymilvus import DataType - def create_collection( - self, - collection_name: str, - dim: int, - enable_dynamic_schema: bool = True - ): if self.client.has_collection(collection_name=collection_name): self.client.drop_collection(collection_name=collection_name) @@ -42,17 +40,13 @@ def create_collection( schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim) index_params = self.client.prepare_index_params() - index_params.add_index( - field_name="vector", - index_type="AUTOINDEX", - metric_type="COSINE" - ) + index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE") self.client.create_collection( collection_name=collection_name, schema=schema, index_params=index_params, - enable_dynamic_schema=enable_dynamic_schema + enable_dynamic_schema=enable_dynamic_schema, ) @staticmethod @@ -61,9 +55,9 @@ def build_filter(key, value) -> str: filter_expression = f'{key} == "{value}"' else: if isinstance(value, list): - filter_expression = f'{key} in {value}' + filter_expression = f"{key} in {value}" else: - filter_expression = f'{key} == {value}' + filter_expression = f"{key} == {value}" return filter_expression @@ -71,14 +65,11 @@ def search( self, collection_name: str, query: List[float], - filter: Dict[str, str | int | list[int]] = None, + filter: Dict = None, limit: int = 10, output_fields: Optional[List[str]] = None, ) -> List[dict]: - filter_expression = '' - - for key, value in filter.items(): - filter_expression += f'{self.build_filter(key, value)} and ' + filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()]) print(filter_expression) res = self.client.search( @@ -91,34 +82,18 @@ def search( return res - def add( - self, - collection_name: str, - _ids: List[str], - vector: List[List[float]], - metadata: List[Dict[str, Any]] - ): + def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]): data = dict() for i, id in enumerate(_ids): - data['id'] = id - data['vector'] = vector[i] - data['metadata'] = metadata[i] + data["id"] = id + data["vector"] = vector[i] + data["metadata"] = metadata[i] - self.client.upsert( - collection_name=collection_name, - data=data - ) + self.client.upsert(collection_name=collection_name, data=data) - def delete( - self, - collection_name: str, - _ids: List[str] - ): - self.client.delete( - collection_name=collection_name, - ids=_ids - ) + def delete(self, collection_name: str, _ids: List[str]): + self.client.delete(collection_name=collection_name, ids=_ids) def write(self, *args, **kwargs): pass diff --git a/tests/metagpt/document_store/test_milvus_store.py b/tests/metagpt/document_store/test_milvus_store.py index 7cfd31381..93d4187f9 100644 --- a/tests/metagpt/document_store/test_milvus_store.py +++ b/tests/metagpt/document_store/test_milvus_store.py @@ -1,4 +1,7 @@ import random + +import pytest + from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore seed_value = 42 @@ -19,6 +22,7 @@ def assert_almost_equal(actual, expected): assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}" +@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default def test_milvus_store(): milvus_connection = MilvusConnection(uri="./milvus_local.db") milvus_store = MilvusStore(milvus_connection) @@ -33,11 +37,7 @@ def test_milvus_store(): first_result = search_results[0] assert first_result["id"] == "doc_0" - search_results_with_filter = milvus_store.search( - collection_name, - query=[1.0] * 8, - filter={"rand_number": 1} - ) + search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1}) assert len(search_results_with_filter) > 0 assert search_results_with_filter[0]["id"] == "doc_1"