Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚀 feat: Add Atlas MongoDB as an option for Vector Store #21

Merged
merged 16 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,38 @@ The following environment variables are required to run the application:

Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.

### Use Atlas MongoDB as Vector Database
danny-avila marked this conversation as resolved.
Show resolved Hide resolved

Instead of the default pgvector, you can use [Atlas MongoDB](https://www.mongodb.com/products/platform/atlas-vector-search) as the vector database. To do so, set the following environment variables:

```env
VECTOR_DB_TYPE=atlas-mongo
ATLAS_MONGO_DB_URI=<mongodb+srv://...>
MONGO_VECTOR_COLLECTION=<collection name>
```

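In addition, make sure the collection defined by `$MONGO_VECTOR_COLLECTION` has the following vector search index created (`numDimensions` must match the output size of your embedding model; 1536 is the size of OpenAI's `text-embedding-3-small`):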

```json
{
"fields": [
{
"numDimensions": 1536,
"path": "embedding",
"similarity": "cosine",
"type": "vector"
},
{
"path": "file_id",
"type": "filter"
}
]
}
```

Follow one of the [four documented methods](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure) to create the vector index.


### Cloud Installation Settings:

#### AWS:
Expand Down
92 changes: 61 additions & 31 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from datetime import datetime

from dotenv import find_dotenv, load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings, \
HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_community.embeddings import (
HuggingFaceEmbeddings,
HuggingFaceHubEmbeddings,
OllamaEmbeddings,
)
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from starlette.middleware.base import BaseHTTPMiddleware

Expand All @@ -15,27 +18,37 @@
load_dotenv(find_dotenv())


def get_env_variable(
    var_name: str, default_value: "str | None" = None, required: bool = False
) -> "str | None":
    """Read an environment variable, falling back to a default.

    Args:
        var_name: Name of the environment variable to look up.
        default_value: Value returned when the variable is not set.
        required: When true and no default is supplied, a missing variable
            is treated as an error instead of returning ``None``.

    Returns:
        The variable's value if it is set (an empty string counts as set),
        otherwise ``default_value`` (which may be ``None``).

    Raises:
        ValueError: If the variable is unset, ``required`` is true and
            ``default_value`` is ``None``.
    """
    value = os.getenv(var_name)
    if value is not None:
        return value
    # A non-None default always satisfies `required`; only a missing
    # variable with no fallback at all is an error.
    if required and default_value is None:
        raise ValueError(f"Environment variable '{var_name}' not found.")
    return default_value

RAG_HOST = os.getenv('RAG_HOST', '0.0.0.0')
RAG_PORT = int(os.getenv('RAG_PORT', 8000))

RAG_HOST = os.getenv("RAG_HOST", "0.0.0.0")
RAG_PORT = int(os.getenv("RAG_PORT", 8000))

RAG_UPLOAD_DIR = get_env_variable("RAG_UPLOAD_DIR", "./uploads/")
if not os.path.exists(RAG_UPLOAD_DIR):
os.makedirs(RAG_UPLOAD_DIR, exist_ok=True)

VECTOR_DB_TYPE = get_env_variable("VECTOR_DB_TYPE", "pgvector")
POSTGRES_DB = get_env_variable("POSTGRES_DB", "mydatabase")
POSTGRES_USER = get_env_variable("POSTGRES_USER", "myuser")
POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD", "mypassword")
DB_HOST = get_env_variable("DB_HOST", "db")
DB_PORT = get_env_variable("DB_PORT", "5432")
COLLECTION_NAME = get_env_variable("COLLECTION_NAME", "testcollection")
ATLAS_MONGO_DB_URI = get_env_variable(
"ATLAS_MONGO_DB_URI", "mongodb://127.0.0.1:27018/LibreChat"
)
MONGO_VECTOR_COLLECTION = get_env_variable(
"MONGO_VECTOR_COLLECTION", "vector_collection"
)

CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100"))
Expand All @@ -62,6 +75,7 @@ def get_env_variable(var_name: str, default_value: str = None, required: bool =
logger.setLevel(logging.INFO)

if console_json:

class JsonFormatter(logging.Formatter):
def __init__(self):
super(JsonFormatter, self).__init__()
Expand Down Expand Up @@ -96,7 +110,8 @@ def format(self, record):
formatter = JsonFormatter()
else:
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

handler = logging.StreamHandler() # or logging.FileHandler("app.log")
handler.setFormatter(formatter)
Expand All @@ -113,12 +128,11 @@ async def dispatch(self, request, call_next):
logger_method = logger.debug

logger_method(
f"Request {request.method} {request.url} - {response.status_code}",
extra={
HTTP_REQ: {"method": request.method,
"url": str(request.url)},
HTTP_RES: {"status_code": response.status_code},
},
f"Request {request.method} {request.url} - {response.status_code}",
extra={
HTTP_REQ: {"method": request.method, "url": str(request.url)},
HTTP_RES: {"status_code": response.status_code},
},
)

return response
Expand All @@ -135,33 +149,38 @@ async def dispatch(self, request, call_next):
RAG_OPENAI_PROXY = get_env_variable("RAG_OPENAI_PROXY", None)
AZURE_OPENAI_API_KEY = get_env_variable("AZURE_OPENAI_API_KEY", "")
RAG_AZURE_OPENAI_API_VERSION = get_env_variable("RAG_AZURE_OPENAI_API_VERSION", None)
RAG_AZURE_OPENAI_API_KEY = get_env_variable("RAG_AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY)
RAG_AZURE_OPENAI_API_KEY = get_env_variable(
"RAG_AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY
)
AZURE_OPENAI_ENDPOINT = get_env_variable("AZURE_OPENAI_ENDPOINT", "")
RAG_AZURE_OPENAI_ENDPOINT = get_env_variable("RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT).rstrip("/")
RAG_AZURE_OPENAI_ENDPOINT = get_env_variable(
"RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT
).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")


## Embeddings


def init_embeddings(provider, model):
if provider == "openai":
return OpenAIEmbeddings(
model=model,
api_key=RAG_OPENAI_API_KEY,
openai_api_base=RAG_OPENAI_BASEURL,
openai_proxy=RAG_OPENAI_PROXY
openai_proxy=RAG_OPENAI_PROXY,
)
elif provider == "azure":
return AzureOpenAIEmbeddings(
azure_deployment=model,
api_key=RAG_AZURE_OPENAI_API_KEY,
azure_endpoint=RAG_AZURE_OPENAI_ENDPOINT,
api_version=RAG_AZURE_OPENAI_API_VERSION
api_version=RAG_AZURE_OPENAI_API_VERSION,
)
elif provider == "huggingface":
return HuggingFaceEmbeddings(model_name=model, encode_kwargs={
'normalize_embeddings': True})
return HuggingFaceEmbeddings(
model_name=model, encode_kwargs={"normalize_embeddings": True}
)
elif provider == "huggingfacetei":
return HuggingFaceHubEmbeddings(model=model)
elif provider == "ollama":
Expand All @@ -173,20 +192,20 @@ def init_embeddings(provider, model):
EMBEDDINGS_PROVIDER = get_env_variable("EMBEDDINGS_PROVIDER", "openai").lower()

if EMBEDDINGS_PROVIDER == "openai":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"text-embedding-3-small")
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "azure":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"text-embedding-3-small")
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "huggingface":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2")
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)

elif EMBEDDINGS_PROVIDER == "huggingfacetei":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"http://huggingfacetei:3000")
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "http://huggingfacetei:3000"
)

elif EMBEDDINGS_PROVIDER == "ollama":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
Expand All @@ -197,14 +216,25 @@ def init_embeddings(provider, model):

logger.info(f"Initialized embeddings of type: {type(embeddings)}")

## Vector store

vector_store = get_vector_store(
# Vector store
if VECTOR_DB_TYPE == "pgvector":
vector_store = get_vector_store(
connection_string=CONNECTION_STRING,
embeddings=embeddings,
collection_name=COLLECTION_NAME,
mode="async",
)
)
elif VECTOR_DB_TYPE == "atlas-mongo":
# atlas-mongo vector:
vector_store = get_vector_store(
connection_string=ATLAS_MONGO_DB_URI,
embeddings=embeddings,
collection_name=MONGO_VECTOR_COLLECTION,
mode="atlas-mongo",
)
else:
raise ValueError(f"Unsupported vector store type: {VECTOR_DB_TYPE}")

retriever = vector_store.as_retriever()

known_source_ext = [
Expand Down
17 changes: 11 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,16 @@
# RAG_EMBEDDING_MODEL,
# RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# RAG_TEMPLATE,
VECTOR_DB_TYPE,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup logic goes here
await PSQLDatabase.get_pool() # Initialize the pool
await ensure_custom_id_index_on_embedding()
if VECTOR_DB_TYPE == "pgvector":
await PSQLDatabase.get_pool() # Initialize the pool
await ensure_custom_id_index_on_embedding()

yield

Expand Down Expand Up @@ -105,7 +107,10 @@ async def get_all_ids():


def isHealthOK():
    """Report whether the configured vector store backend is healthy.

    Returns:
        The result of ``pg_health_check()`` when ``VECTOR_DB_TYPE`` is
        ``"pgvector"``; ``True`` for any other backend.
    """
    if VECTOR_DB_TYPE == "pgvector":
        return pg_health_check()
    # No probe is implemented here for non-pgvector backends, so they are
    # reported healthy unconditionally.
    return True


@app.get("/health")
Expand Down Expand Up @@ -137,7 +142,7 @@ async def get_documents_by_ids(ids: list[str] = Query(...)):


@app.delete("/documents")
async def delete_documents(ids: list[str]):
async def delete_documents(ids: list[str] = Query(...)):
try:
if isinstance(vector_store, AsyncPgVector):
existing_ids = await vector_store.get_all_ids()
Expand Down Expand Up @@ -497,11 +502,11 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody):
vector_store.similarity_search_with_score_by_vector,
embedding,
k=body.k,
filter={"custom_id": {"$in": body.file_ids}},
filter={"file_id": {"$in": body.file_ids}},
)
else:
documents = vector_store.similarity_search_with_score_by_vector(
embedding, k=body.k, filter={"custom_id": {"$in": body.file_ids}}
embedding, k=body.k, filter={"file_id": {"$in": body.file_ids}}
)

return documents
Expand Down
2 changes: 2 additions & 0 deletions requirements.lite.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ python-multipart==0.0.9
aiofiles==23.2.1
rapidocr-onnxruntime==1.3.17
opencv-python-headless==4.9.0.80
pymongo==4.6.3
langchain-mongodb==0.1.3
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ sentence_transformers==2.5.1
aiofiles==23.2.1
rapidocr-onnxruntime==1.3.17
opencv-python-headless==4.9.0.80
pymongo==4.6.3
langchain-mongodb==0.1.3
Loading