🚀 feat: Add Atlas MongoDB as an option for Vector Store #21

Merged 16 commits on May 11, 2024
Changes from 3 commits
15 changes: 12 additions & 3 deletions config.py
@@ -139,7 +139,7 @@ async def dispatch(self, request, call_next):
RAG_AZURE_OPENAI_ENDPOINT = get_env_variable("RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")

ATLAS_MONGO_DB_URI = get_env_variable("ATLAS_MONGO_DB_URI", "")

## Embeddings

@@ -197,11 +197,20 @@ def init_embeddings(provider, model):

## Vector store

# This was the pgvector:
# vector_store = get_vector_store(
# connection_string=CONNECTION_STRING,
# embeddings=embeddings,
# collection_name=COLLECTION_NAME,
# mode="async",
# )

# new atlas-mongo vector:
vector_store = get_vector_store(
Owner: I know this is a draft, but some way to initialize one or the other would be good.

-    connection_string=CONNECTION_STRING,
+    connection_string=ATLAS_MONGO_DB_URI,
    embeddings=embeddings,
    collection_name=COLLECTION_NAME,
-    mode="async",
+    mode="atlas-mongo",
)
retriever = vector_store.as_retriever()

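On the Owner's note above about initializing one store or the other: a minimal sketch of how config.py could pick the backend from an environment variable. VECTOR_DB_TYPE is an assumed name for illustration and is not part of this PR; the other names come from config.py as shown in the diff.

```python
# Hypothetical selection logic for config.py; VECTOR_DB_TYPE is an assumed
# env variable name, not something this PR introduces.
VECTOR_DB_TYPE = get_env_variable("VECTOR_DB_TYPE", "pgvector")

if VECTOR_DB_TYPE == "atlas-mongo":
    vector_store = get_vector_store(
        connection_string=ATLAS_MONGO_DB_URI,
        embeddings=embeddings,
        collection_name=COLLECTION_NAME,
        mode="atlas-mongo",
    )
else:
    # Default to the existing async pgvector store.
    vector_store = get_vector_store(
        connection_string=CONNECTION_STRING,
        embeddings=embeddings,
        collection_name=COLLECTION_NAME,
        mode="async",
    )

retriever = vector_store.as_retriever()
```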
11 changes: 6 additions & 5 deletions main.py
@@ -27,8 +27,8 @@
)

from models import DocumentResponse, StoreDocument, QueryRequestBody, QueryMultipleBody
-from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, \
-    pg_health_check
+# from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, \
+#     pg_health_check
from middleware import security_middleware
from pgvector_routes import router as pgvector_router
from parsers import process_documents
@@ -57,8 +57,8 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup logic goes here
-    await PSQLDatabase.get_pool()  # Initialize the pool
-    await ensure_custom_id_index_on_embedding()
+    # await PSQLDatabase.get_pool()  # Initialize the pool
Owner: I know it's a draft, but it would be good to initialize differently here as well; functions could probably be imported from a different module.

Contributor Author: I can start working on an async implementation of the MongoDB code, and that might be a good time to do these things.

Owner (replying to the above): Awesome, thanks for doing that!
+    # await ensure_custom_id_index_on_embedding()

yield

@@ -95,7 +95,8 @@ async def get_all_ids():


def isHealthOK():
-    return pg_health_check()
+    # return pg_health_check()
+    return True


@app.get("/health")
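Following the review thread above, a rough sketch of how main.py's startup and health check could branch per backend. VECTOR_DB_TYPE and mongo_health_check are assumed names for illustration; only the psql imports and pg_health_check come from the existing code, and the config import path is an assumption.

```python
# Sketch of a backend-aware lifespan and health check for main.py.
# VECTOR_DB_TYPE and mongo_health_check are assumptions, not code from this PR.
from contextlib import asynccontextmanager

from fastapi import FastAPI
from pymongo import MongoClient

from config import ATLAS_MONGO_DB_URI, VECTOR_DB_TYPE  # VECTOR_DB_TYPE is hypothetical
from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, pg_health_check


def mongo_health_check() -> bool:
    # "ping" is a cheap server round-trip supported by MongoDB.
    try:
        MongoClient(ATLAS_MONGO_DB_URI).admin.command("ping")
        return True
    except Exception:
        return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    if VECTOR_DB_TYPE != "atlas-mongo":
        await PSQLDatabase.get_pool()  # Initialize the pool only for pgvector
        await ensure_custom_id_index_on_embedding()
    yield


def isHealthOK():
    if VECTOR_DB_TYPE == "atlas-mongo":
        return mongo_health_check()
    return pg_health_check()
```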
2 changes: 2 additions & 0 deletions requirements.lite.txt
@@ -20,3 +20,5 @@ python-jose==3.3.0
asyncpg==0.29.0
python-multipart==0.0.9
aiofiles==23.2.1
pymongo==4.6.3
langchain-mongodb==0.1.3
2 changes: 2 additions & 0 deletions requirements.txt
@@ -21,3 +21,5 @@ asyncpg==0.29.0
python-multipart==0.0.9
sentence_transformers==2.5.1
aiofiles==23.2.1
pymongo==4.6.3
langchain-mongodb==0.1.3
48 changes: 48 additions & 0 deletions store.py
@@ -5,6 +5,14 @@
from langchain_core.runnables.config import run_in_executor
from sqlalchemy.orm import Session

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_core.embeddings import Embeddings
from typing import (
    Any,
    List,
    Optional,
    Tuple,
)
import copy

class ExtendedPgVector(PGVector):

def get_all_ids(self) -> list[str]:
@@ -67,3 +75,43 @@ async def delete(
collection_only: bool = False
) -> None:
await run_in_executor(None, self._delete_multiple, ids, collection_only)

class AtlasMongoVector(MongoDBAtlasVectorSearch):
    @property
    def embedding_function(self) -> Embeddings:
        return self.embeddings

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        docs = self._similarity_search_with_score(
            embedding,
            k=k,
            pre_filter=filter,
            post_filter_pipeline=None,
            **kwargs,
        )
        # Remove `metadata._id`, since a MongoDB ObjectId is not JSON serializable.
        processed_documents: List[Tuple[Document, float]] = []
        for document, score in docs:
            # Make a deep copy of the document to avoid mutating the original.
            doc_copy = copy.deepcopy(document.__dict__)  # Adjust if Document is not attribute-based

            # Remove the _id field from metadata if it exists.
            if 'metadata' in doc_copy and '_id' in doc_copy['metadata']:
                del doc_copy['metadata']['_id']

            # Create a new Document instance without the _id.
            new_document = Document(**doc_copy)

            # Append the new document and score to the list as a tuple.
            processed_documents.append((new_document, score))
        return processed_documents

    def get_all_ids(self) -> list[str]:
        # Draft implementation: return every document id in the collection as a string.
        return [str(doc["_id"]) for doc in self._collection.find({}, {"_id": 1})]
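On the Contributor's point about an async implementation: a minimal sketch of how the Atlas store could be wrapped the same way AsyncPgVector wraps ExtendedPgVector, by offloading blocking calls with run_in_executor. The AsyncAtlasMongoVector class is an assumption for illustration and is not part of this PR.

```python
# Hypothetical async wrapper, mirroring the AsyncPgVector pattern used above;
# not part of this PR. run_in_executor is already imported at the top of store.py.
class AsyncAtlasMongoVector(AtlasMongoVector):
    async def get_all_ids(self) -> list[str]:
        # Run the blocking pymongo query in a thread pool.
        return await run_in_executor(None, super().get_all_ids)

    async def aadd_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        # Offload the synchronous add_documents call.
        return await run_in_executor(None, self.add_documents, documents, **kwargs)
```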
9 changes: 8 additions & 1 deletion store_factory.py
@@ -1,7 +1,8 @@
from langchain_community.embeddings import OpenAIEmbeddings

from store import AsyncPgVector, ExtendedPgVector

from store import AtlasMongoVector
from pymongo import MongoClient

def get_vector_store(
connection_string: str,
@@ -21,9 +22,15 @@ def get_vector_store(
embedding_function=embeddings,
collection_name=collection_name,
)
elif mode == "atlas-mongo":
    mongo_db = MongoClient(connection_string).get_database()
    mongo_collection = mongo_db[collection_name]
    return AtlasMongoVector(collection=mongo_collection, embedding=embeddings)

else:
    raise ValueError("Invalid mode specified. Choose 'sync', 'async', or 'atlas-mongo'.")


async def create_index_if_not_exists(conn, table_name: str, column_name: str):
# Construct index name conventionally
index_name = f"idx_{table_name}_{column_name}"
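A usage note on the new atlas-mongo mode: get_vector_store() calls MongoClient(connection_string).get_database() with no arguments, so the connection string must name a default database or pymongo raises ConfigurationError. A short sketch with placeholder credentials, database, and collection names:

```python
# Usage sketch for the atlas-mongo mode; URI, database, and collection names are placeholders.
from langchain_community.embeddings import OpenAIEmbeddings
from store_factory import get_vector_store

# The URI names a default database ("rag_api") because get_vector_store()
# calls MongoClient(...).get_database() with no arguments.
ATLAS_MONGO_DB_URI = "mongodb+srv://user:password@cluster0.example.mongodb.net/rag_api"

vector_store = get_vector_store(
    connection_string=ATLAS_MONGO_DB_URI,
    embeddings=OpenAIEmbeddings(),   # requires OPENAI_API_KEY in the environment
    collection_name="testcollection",
    mode="atlas-mongo",
)
retriever = vector_store.as_retriever()
```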