Skip to content

Commit

Permalink
fix:.env customize vector store config does not work
Browse files Browse the repository at this point in the history
Close #655
  • Loading branch information
Aries-ckt committed Oct 8, 2023
1 parent f2427b1 commit 0ff63fe
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 20 deletions.
8 changes: 8 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,21 @@ DENYLISTED_PLUGINS=
#*******************************************************************#
#** VECTOR STORE SETTINGS **#
#*******************************************************************#
### Chroma vector db config
VECTOR_STORE_TYPE=Chroma
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data

### Milvus vector db config
#VECTOR_STORE_TYPE=Milvus
#MILVUS_URL=127.0.0.1
#MILVUS_PORT=19530
#MILVUS_USERNAME
#MILVUS_PASSWORD
#MILVUS_SECURE=

### Weaviate vector db config
#VECTOR_STORE_TYPE=Weaviate
#WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network

#*******************************************************************#
#** WebServer Language Support **#
Expand Down
3 changes: 1 addition & 2 deletions pilot/scene/chat_knowledge/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __init__(self, chat_param: Dict):
vector_store_config = {
"vector_store_name": self.knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
Expand Down Expand Up @@ -93,7 +92,7 @@ def generate_input_values(self):
context = [d.page_content for d in docs]
context = context[: self.max_token]
relations = list(
set([os.path.basename(d.metadata.get("source")) for d in docs])
set([os.path.basename(d.metadata.get("source", "")) for d in docs])
)
input_values = {
"context": context,
Expand Down
1 change: 0 additions & 1 deletion pilot/server/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest)
vector_store_config={
"vector_store_name": space_name,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
text_splitter=text_splitter,
embedding_factory=embedding_factory,
Expand Down
8 changes: 0 additions & 8 deletions pilot/summary/db_summary_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def db_summary_embedding(self, dbname, db_type):
vector_store_config = {
"vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"embeddings": embeddings,
}
embedding = StringEmbedding(
Expand Down Expand Up @@ -73,7 +72,6 @@ def db_summary_embedding(self, dbname, db_type):
table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"embeddings": embeddings,
}
embedding = StringEmbedding(
Expand All @@ -91,7 +89,6 @@ def get_db_summary(self, dbname, query, topk):
vector_store_config = {
"vector_store_name": dbname + "_profile",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
Expand All @@ -112,9 +109,7 @@ def get_similar_tables(self, dbname, query, topk):

vector_store_config = {
"vector_store_name": dbname + "_summary",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
Expand Down Expand Up @@ -142,9 +137,7 @@ def get_similar_tables(self, dbname, query, topk):
for table in related_tables:
vector_store_config = {
"vector_store_name": dbname + "_" + table + "_ts",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
Expand Down Expand Up @@ -172,7 +165,6 @@ def init_db_profile(self, db_summary_client, dbname, embeddings):
vector_store_name = dbname + "_profile"
profile_store_config = {
"vector_store_name": vector_store_name,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
}
Expand Down
2 changes: 1 addition & 1 deletion pilot/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def similar_search(self, text, topk) -> None:
pass

@abstractmethod
def vector_name_exists(self, text, topk) -> None:
def vector_name_exists(self) -> None:
"""is vector store name exist."""
pass

Expand Down
7 changes: 5 additions & 2 deletions pilot/vector_store/chroma_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ def __init__(self, ctx: {}) -> None:
from langchain.vectorstores import Chroma

self.ctx = ctx
self.embeddings = ctx.get("embeddings", None)
chroma_path = ctx.get(
"CHROMA_PERSIST_PATH", os.getenv("CHROMA_PERSIST_PATH", os.getcwd())
)
self.persist_dir = os.path.join(
ctx["chroma_persist_path"], ctx["vector_store_name"] + ".vectordb"
chroma_path, ctx["vector_store_name"] + ".vectordb"
)
self.embeddings = ctx.get("embeddings", None)
chroma_settings = Settings(
# chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
persist_directory=self.persist_dir,
Expand Down
11 changes: 6 additions & 5 deletions pilot/vector_store/milvus_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import os
from typing import Any, Iterable, List, Optional, Tuple

from pymilvus import Collection, DataType, connections, utility
Expand All @@ -21,12 +22,12 @@ def __init__(self, ctx: {}) -> None:
# self.configure(cfg)

connect_kwargs = {}
self.uri = ctx.get("milvus_url", None)
self.port = ctx.get("milvus_port", None)
self.username = ctx.get("milvus_username", None)
self.password = ctx.get("milvus_password", None)
self.uri = ctx.get("MILVUS_URL", os.getenv("MILVUS_URL"))
self.port = ctx.get("MILVUS_PORT", os.getenv("MILVUS_PORT"))
self.username = ctx.get("MILVUS_USERNAME", os.getenv("MILVUS_USERNAME"))
self.password = ctx.get("MILVUS_PASSWORD", os.getenv("MILVUS_PASSWORD"))
self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE"))
self.collection_name = ctx.get("vector_store_name", None)
self.secure = ctx.get("secure", None)
self.embedding = ctx.get("embeddings", None)
self.fields = []
self.alias = "default"
Expand Down
2 changes: 1 addition & 1 deletion pilot/vector_store/weaviate_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, ctx: dict) -> None:
)

self.ctx = ctx
self.weaviate_url = CFG.WEAVIATE_URL
self.weaviate_url = ctx.get("WEAVIATE_URL", os.getenv("WEAVIATE_URL"))
self.embedding = ctx.get("embeddings", None)
self.vector_name = ctx["vector_store_name"]
self.persist_dir = os.path.join(
Expand Down

0 comments on commit 0ff63fe

Please sign in to comment.