-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1457 from Jacksonxhx/milvus
Integrated Milvus with MetaGPT
- Loading branch information
Showing
10 changed files
with
261 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List, Optional | ||
|
||
from metagpt.document_store.base_store import BaseStore | ||
|
||
|
||
@dataclass
class MilvusConnection:
    """Connection settings for a Milvus instance.

    Args:
        uri: Milvus URI. Defaults to ``None``.
        token: Milvus token. Defaults to ``None``.
    """

    # Both fields default to None, so they are annotated as Optional.
    uri: Optional[str] = None
    token: Optional[str] = None
|
||
|
||
class MilvusStore(BaseStore):
    """Vector store that talks to Milvus through ``pymilvus.MilvusClient``."""

    def __init__(self, connect: MilvusConnection):
        """Create the underlying Milvus client.

        Args:
            connect: Connection settings; ``connect.uri`` must be non-empty.

        Raises:
            ImportError: If ``pymilvus`` is not installed.
            ValueError: If ``connect.uri`` is not set.
        """
        # Imported lazily so that pymilvus stays an optional dependency.
        try:
            from pymilvus import MilvusClient
        except ImportError as err:
            raise ImportError("Please install pymilvus first.") from err
        if not connect.uri:
            raise ValueError("please check MilvusConnection, uri must be set.")
        self.client = MilvusClient(uri=connect.uri, token=connect.token)

    def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
        """Create ``collection_name`` from scratch.

        If a collection with this name already exists it is dropped first, so
        any data it held is lost. The new collection has a VARCHAR primary key
        ``id`` (max length 36) and a FLOAT_VECTOR field ``vector`` indexed with
        AUTOINDEX / COSINE.

        Args:
            collection_name: Name of the collection to (re)create.
            dim: Dimension of the ``vector`` field.
            enable_dynamic_schema: Whether the schema accepts fields that are
                not declared explicitly.
        """
        from pymilvus import DataType

        if self.client.has_collection(collection_name=collection_name):
            self.client.drop_collection(collection_name=collection_name)

        # Honour the caller's flag in the schema itself; it used to be
        # hard-coded to False here, which ignored ``enable_dynamic_schema``.
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=enable_dynamic_schema,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)

        index_params = self.client.prepare_index_params()
        index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
            # Kept as before for backward compatibility.
            # NOTE(review): confirm whether pymilvus reads this keyword.
            enable_dynamic_schema=enable_dynamic_schema,
        )

    @staticmethod
    def build_filter(key, value) -> str:
        """Render one ``key``/``value`` pair as a filter expression.

        Strings become ``key == "value"`` (backslashes and double quotes in
        the value are escaped), lists become ``key in [...]`` and anything
        else becomes ``key == value``.
        """
        if isinstance(value, str):
            # Escape so an embedded quote cannot terminate the literal early.
            escaped = value.replace("\\", "\\\\").replace('"', '\\"')
            return f'{key} == "{escaped}"'
        if isinstance(value, list):
            return f"{key} in {value}"
        return f"{key} == {value}"

    def search(
        self,
        collection_name: str,
        query: List[float],
        filter: Optional[Dict] = None,
        limit: int = 10,
        output_fields: Optional[List[str]] = None,
    ) -> List[dict]:
        """Search ``collection_name`` with a single query vector.

        Args:
            collection_name: Collection to search.
            query: The query vector.
            filter: Optional mapping of field name to value; the entries are
                rendered with :meth:`build_filter` and joined with ``and``.
                ``None`` or an empty mapping means no filtering.
            limit: Maximum number of hits requested.
            output_fields: Fields requested from the client for each hit.

        Returns:
            The first element of the client's search result, i.e. the hits
            for the single query vector.
        """
        # ``filter`` defaults to None; treat that the same as "no conditions".
        conditions = filter or {}
        filter_expression = " and ".join(self.build_filter(key, value) for key, value in conditions.items())

        res = self.client.search(
            collection_name=collection_name,
            data=[query],
            filter=filter_expression,
            limit=limit,
            output_fields=output_fields,
        )[0]

        return res

    def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
        """Upsert one row per id into ``collection_name``.

        Args:
            collection_name: Target collection.
            _ids: Primary keys, one per row.
            vector: Vectors, aligned with ``_ids``.
            metadata: Metadata dicts, aligned with ``_ids``; each is stored
                under the ``metadata`` key of its row.

        Raises:
            ValueError: If the three lists differ in length.
        """
        if not (len(_ids) == len(vector) == len(metadata)):
            raise ValueError(
                f"_ids, vector and metadata must have the same length, "
                f"got {len(_ids)}, {len(vector)} and {len(metadata)}."
            )
        if not _ids:
            return

        # One dict per row. A single dict reused across iterations would leave
        # only the last record to be upserted.
        rows = [
            {"id": doc_id, "vector": vec, "metadata": meta}
            for doc_id, vec, meta in zip(_ids, vector, metadata)
        ]

        self.client.upsert(collection_name=collection_name, data=rows)

    def delete(self, collection_name: str, _ids: List[str]):
        """Delete the rows whose primary keys are in ``_ids``."""
        self.client.delete(collection_name=collection_name, ids=_ids)

    def write(self, *args, **kwargs):
        """Not implemented; accepts any arguments and does nothing."""
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Milvus retriever.""" | ||
|
||
from llama_index.core.retrievers import VectorIndexRetriever | ||
from llama_index.core.schema import BaseNode | ||
|
||
|
||
class MilvusRetriever(VectorIndexRetriever):
    """VectorIndexRetriever subclass that adds ``add_nodes`` and ``persist``."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Insert ``nodes`` into the wrapped index, forwarding ``kwargs``."""
        index = self._index
        index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Accept a persist request without doing anything.

        Milvus automatically saves, so there is no need to implement.
        """
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import random | ||
|
||
import pytest | ||
|
||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore | ||
|
||
# Seed the global RNG so the generated vectors are reproducible.
seed_value = 42
random.seed(seed_value)

# 10 fixture documents: an 8-dimensional random vector, an id and a metadata dict each.
vectors = [[random.random() for _col in range(8)] for _row in range(10)]
ids = ["doc_{}".format(n) for n in range(10)]
metadata = [dict(color="red", rand_number=n % 10) for n in range(10)]
|
||
|
||
def assert_almost_equal(actual, expected):
    """Assert that ``actual`` is within 1e-10 of ``expected``.

    When ``expected`` is a list, ``actual`` must have the same length and the
    elements are compared pairwise; otherwise the two values are compared
    directly.
    """
    delta = 1e-10

    # Scalar case first, so the list handling below stays flat.
    if not isinstance(expected, list):
        assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
        return

    assert len(actual) == len(expected)
    for ac, exp in zip(actual, expected):
        gap = abs(ac - exp)
        assert gap <= delta, f"{ac} is not within {delta} of {exp}"
|
||
|
||
# Skip because the pymilvus dependency is not installed by default
@pytest.mark.skip(reason="the pymilvus dependency is not installed by default")
def test_milvus_store():
    """Exercise create/add/search/delete of MilvusStore against a local Milvus DB file."""
    milvus_connection = MilvusConnection(uri="./milvus_local.db")
    milvus_store = MilvusStore(milvus_connection)

    collection_name = "TestCollection"
    milvus_store.create_collection(collection_name, dim=8)

    # Drop the collection even when an assertion fails, so a failing run does
    # not leave data behind.
    try:
        milvus_store.add(collection_name, ids, vectors, metadata)

        search_results = milvus_store.search(collection_name, query=[1.0] * 8)
        assert len(search_results) > 0
        first_result = search_results[0]
        assert first_result["id"] == "doc_0"

        search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
        assert len(search_results_with_filter) > 0
        assert search_results_with_filter[0]["id"] == "doc_1"

        milvus_store.delete(collection_name, _ids=["doc_0"])
        deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
        assert deleted_results[0]["id"] != "doc_0"
    finally:
        milvus_store.client.drop_collection(collection_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.