diff --git a/autollm/auto/vector_store_index.py b/autollm/auto/vector_store_index.py index 28b3ab96..9ca267b1 100644 --- a/autollm/auto/vector_store_index.py +++ b/autollm/auto/vector_store_index.py @@ -6,6 +6,7 @@ from llama_index.schema import BaseNode from autollm.utils.env_utils import on_rm_error +from autollm.utils.lancedb_vectorstore import LanceDBVectorStore from autollm.utils.logging import logger @@ -32,6 +33,8 @@ def from_defaults( vector_store_type: str = "LanceDBVectorStore", lancedb_uri: str = None, lancedb_table_name: str = "vectors", + lancedb_api_key: Optional[str] = None, + lancedb_region: Optional[str] = None, documents: Optional[Sequence[Document]] = None, nodes: Optional[Sequence[BaseNode]] = None, service_context: Optional[ServiceContext] = None, @@ -75,7 +78,15 @@ def from_defaults( exist_ok=exist_ok, overwrite_existing=overwrite_existing) + vector_store = LanceDBVectorStore( + uri=lancedb_uri, + table_name=lancedb_table_name, + api_key=lancedb_api_key, + region=lancedb_region, + **kwargs) + vector_store = VectorStoreClass(uri=lancedb_uri, table_name=lancedb_table_name, **kwargs) + else: vector_store = VectorStoreClass(**kwargs) diff --git a/autollm/utils/lancedb_vectorstore.py b/autollm/utils/lancedb_vectorstore.py new file mode 100644 index 00000000..13996733 --- /dev/null +++ b/autollm/utils/lancedb_vectorstore.py @@ -0,0 +1,45 @@ +"""LanceDB vector store with cloud storage support.""" +import os +from typing import Any, Optional + +from dotenv import load_dotenv +from llama_index.vector_stores import LanceDBVectorStore as LanceDBVectorStoreBase + +load_dotenv() + + +class LanceDBVectorStore(LanceDBVectorStoreBase): + + def __init__( + self, + uri: str, + table_name: str = "vectors", + nprobes: int = 20, + refine_factor: Optional[int] = None, + api_key: Optional[str] = None, + region: Optional[str] = None, + **kwargs: Any, + ) -> None: + import_err_msg = "`lancedb` package not found, please run `pip install lancedb`" + try: + import lancedb + except ImportError: + raise ImportError(import_err_msg) + + # Check for API key and region in environment variables if not provided + if api_key is None: + api_key = os.getenv('LANCEDB_API_KEY') + if region is None: + region = os.getenv('LANCEDB_REGION') + + if api_key and region: + self.connection = lancedb.connect(uri, api_key=api_key, region=region) + else: + self.connection = lancedb.connect(uri) + + self.uri = uri + self.table_name = table_name + self.nprobes = nprobes + self.refine_factor = refine_factor + self.api_key = api_key + self.region = region diff --git a/requirements.txt b/requirements.txt index 293458fb..1990d3b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ uvicorn fastapi python-dotenv httpx -lancedb==0.3.3 +lancedb==0.3.4