Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support MyScale vector database #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh

- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
Expand All @@ -89,5 +89,6 @@ jobs:
pgvecto-rs
pgvector
chroma
myscale
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
10 changes: 9 additions & 1 deletion api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ OCI_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
VECTOR_STORE=weaviate

# Weaviate configuration
Expand All @@ -106,6 +106,14 @@ MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false

# MyScale configuration
MYSCALE_HOST=127.0.0.1
MYSCALE_PORT=8123
MYSCALE_USER=default
MYSCALE_PASSWORD=
MYSCALE_DATABASE=default
MYSCALE_FTS_PARAMS=

# Relyt configuration
RELYT_HOST=127.0.0.1
RELYT_PORT=5432
Expand Down
2 changes: 2 additions & 0 deletions api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
Expand Down Expand Up @@ -187,6 +188,7 @@ class MiddlewareConfig(
AnalyticdbConfig,
ChromaConfig,
MilvusConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,
PGVectorConfig,
Expand Down
39 changes: 39 additions & 0 deletions api/configs/middleware/vdb/myscale_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Optional

from pydantic import BaseModel, Field, PositiveInt


class MyScaleConfig(BaseModel):
"""
MyScale configs
"""

MYSCALE_HOST: Optional[str] = Field(
description='MyScale host',
default=None,
)

MYSCALE_PORT: Optional[PositiveInt] = Field(
description='MyScale port',
default=8123,
)

MYSCALE_USER: Optional[str] = Field(
description='MyScale user',
default=None,
)

MYSCALE_PASSWORD: Optional[str] = Field(
description='MyScale password',
default=None,
)

MYSCALE_DATABASE: Optional[str] = Field(
description='MyScale database name',
default=None,
)

MYSCALE_FTS_PARAMS: Optional[str] = Field(
description='MyScale fts index parameters',
default=None,
)
4 changes: 2 additions & 2 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def get(self):
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
Expand All @@ -572,7 +572,7 @@ def get(self, vector_type):
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
Expand Down
Empty file.
170 changes: 170 additions & 0 deletions api/core/rag/datasource/vdb/myscale/myscale_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import json
import logging
import uuid
from enum import Enum
from typing import Any

from clickhouse_connect import get_client
from flask import current_app
from pydantic import BaseModel

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from models.dataset import Dataset


class MyScaleConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
fts_params: str


class SortOrder(Enum):
ASC = "ASC"
DESC = "DESC"


class MyScaleVector(BaseVector):

def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
self._metric = metric
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
self._client = get_client(
host=config.host,
port=config.port,
username=config.user,
password=config.password,
)
self._client.command("SET allow_experimental_object_type=1")

def get_type(self) -> str:
return VectorType.MYSCALE

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)

def _create_collection(self, dimension: int):
logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}")
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
id String,
text String,
vector Array(Float32),
metadata JSON,
CONSTRAINT cons_vec_len CHECK length(vector) = {dimension},
VECTOR INDEX vidx vector TYPE DEFAULT('metric_type = {self._metric}'),
INDEX text_idx text TYPE fts{fts_params}
) ENGINE = MergeTree ORDER BY id
"""
self._client.command(sql)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = []
columns = ["id", "text", "vector", "metadata"]
values = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {}
)
values.append(str(row))
ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}
"""
self._client.command(sql)
return ids

@staticmethod
def escape_str(value: Any) -> str:
return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))

def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
return results.row_count > 0

def delete_by_ids(self, ids: list[str]) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")

def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
).result_rows
return [row[0] for row in rows]

def delete_by_metadata_field(self, key: str, value: str) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
)

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)

def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0)
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
sql = f"""
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
"""
try:
return [
Document(
page_content=r["text"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
]
except Exception as e:
logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []

def delete(self) -> None:
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")


class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))

config = current_app.config
return MyScaleVector(
collection_name=collection_name,
config=MyScaleConfig(
host=config.get("MYSCALE_HOST", "localhost"),
port=int(config.get("MYSCALE_PORT", 8123)),
user=config.get("MYSCALE_USER", "default"),
password=config.get("MYSCALE_PASSWORD", ""),
database=config.get("MYSCALE_DATABASE", "default"),
fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
),
)
3 changes: 3 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
Expand Down
1 change: 1 addition & 0 deletions api/core/rag/datasource/vdb/vector_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class VectorType(str, Enum):
ANALYTICDB = 'analyticdb'
CHROMA = 'chroma'
MILVUS = 'milvus'
MYSCALE = 'myscale'
PGVECTOR = 'pgvector'
PGVECTO_RS = 'pgvecto-rs'
QDRANT = 'qdrant'
Expand Down
Loading