Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Firestore vector embedding support #421

Merged
merged 26 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
699d75c
Add vector indexing for amenities and policies collections
ferdeleong Jun 20, 2024
f6f4202
Replace composite vector indexing for single-field indexing
ferdeleong Jun 21, 2024
14f9794
Save query streams in docs variable to iterate
ferdeleong Jun 21, 2024
afeec37
Add async query
ferdeleong Jul 15, 2024
e2fba59
Delete embedding field from dictionary
ferdeleong Jul 15, 2024
828d500
Explicitly state wanted fields in dictionaries
ferdeleong Jul 15, 2024
45cabb8
Delete content field from amenity dictionary
ferdeleong Jul 15, 2024
ed37d40
cast embeddings back to list type to ensure model validation
ferdeleong Jul 15, 2024
fc6d802
Add conditional since embedding is optional
ferdeleong Jul 16, 2024
88ea1fa
Fix requirements
ferdeleong Jul 16, 2024
4cc8fb3
Fix lints and add empty list to avoid vector type error
ferdeleong Jul 17, 2024
0604c01
Add missing amenity fields to collection creation
ferdeleong Jul 17, 2024
1df3752
Add missing amenity fields to collection creation
ferdeleong Jul 17, 2024
f0cb636
Delete conditional to cast back to list type for amenity and policy
ferdeleong Jul 18, 2024
4bf7621
Enable delete vector indexes if they exist to repopulare the dabatase
ferdeleong Jul 18, 2024
f036c4f
Delete conditional to cast back to list type for get amenity
ferdeleong Jul 18, 2024
c5b3dba
Pin firestore version to a commit
ferdeleong Jul 18, 2024
d914e9c
Add await to delete collections tasks
ferdeleong Jul 18, 2024
8cc18ac
Simplify index existance conditional
ferdeleong Jul 18, 2024
d3f434c
Delete tuple casting error
ferdeleong Jul 18, 2024
f4df0a6
delete collections and indexes as private functions
ferdeleong Jul 22, 2024
e666fec
Abstract create_vecotr_index function
ferdeleong Jul 23, 2024
76ed53b
remove gather task from await functions
ferdeleong Jul 23, 2024
4ff6c42
Fix comment lints
ferdeleong Jul 23, 2024
39472b5
Update list indexes to return a dictionary
ferdeleong Jul 25, 2024
4f66797
Merge branch 'main' into Firestore-Vector-01
ferdeleong Jul 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 230 additions & 8 deletions retrieval_service/datastore/providers/firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

from google.cloud.firestore import AsyncClient # type: ignore
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
from google.cloud.firestore_v1.async_query import AsyncQuery
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.vector import Vector
from pydantic import BaseModel

import models
Expand Down Expand Up @@ -55,7 +58,7 @@ async def initialize_data(
policies: list[models.Policy],
) -> None:
async def delete_collections(collection_list: list[AsyncCollectionReference]):
# Checks if colelction exists and deletes all documents
# Checks if collection exists and deletes all documents
delete_tasks = []
for collection_ref in collection_list:
collection_exists = collection_ref.limit(1).stream()
Expand All @@ -76,6 +79,68 @@ async def delete_collections(collection_list: list[AsyncCollectionReference]):
[airports_ref, amenities_ref, flights_ref, policies_ref]
)

async def delete_indexes(index_list):
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
# Check if the collection exists and deletes all indexes
delete_tasks = []
for index in index_list:
if index:
delete_vector_index = [
"gcloud",
"alpha",
"firestore",
"indexes",
"composite",
"delete",
index,
"--database=(default)",
"--quiet", # Added to suppress delete warning
]

delete_vector_index_process = await asyncio.create_subprocess_exec(
*delete_vector_index,
)
delete_tasks.append(
asyncio.create_task(delete_vector_index_process.wait())
)
await asyncio.gather(*delete_tasks)

# List indexes and retrieve name field (file-path)
list_vector_index = [
"gcloud",
"alpha",
"firestore",
"indexes",
"composite",
"list",
"--database=(default)",
"--format=value(name)", # prints name field
]

list_vector_index_process = await asyncio.create_subprocess_exec(
*list_vector_index,
stdout=asyncio.subprocess.PIPE,
)

# Capture output and ignore stderr
stdout, __ = await list_vector_index_process.communicate()

# Decode and format output
indexes = stdout.decode().strip().split("\n")
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved

# Check if the indexes already exist; if so, delete indexes
# Extract collection and index-id from file path
# Assign the index-id to the corresponding collection
if indexes == [""]:
pass
else:
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
collections = {"amenities": "", "policies": ""}
for line in indexes:
collection, index_id = line.split("/")[-3], line.split("/")[-1]
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
collections[collection] = index_id
amenities_ref = collections["amenities"]
policies_ref = collections["policies"]
await delete_indexes([amenities_ref, policies_ref])
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved

# initialize collections
create_airports_tasks = []
for airport in airports:
Expand Down Expand Up @@ -105,8 +170,80 @@ async def delete_collections(collection_list: list[AsyncCollectionReference]):
"terminal": amenity.terminal,
"category": amenity.category,
"hour": amenity.hour,
# Firebase does not support datetime.time type
"sunday_start_hour": (
str(amenity.sunday_start_hour)
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
if amenity.sunday_start_hour
else None
),
"sunday_end_hour": (
str(amenity.sunday_end_hour)
if amenity.sunday_end_hour
else None
),
"monday_start_hour": (
str(amenity.monday_start_hour)
if amenity.monday_start_hour
else None
),
"monday_end_hour": (
str(amenity.monday_end_hour)
if amenity.monday_end_hour
else None
),
"tuesday_start_hour": (
str(amenity.tuesday_start_hour)
if amenity.tuesday_start_hour
else None
),
"tuesday_end_hour": (
str(amenity.tuesday_end_hour)
if amenity.tuesday_end_hour
else None
),
"wednesday_start_hour": (
str(amenity.wednesday_start_hour)
if amenity.wednesday_start_hour
else None
),
"wednesday_end_hour": (
str(amenity.wednesday_end_hour)
if amenity.wednesday_end_hour
else None
),
"thursday_start_hour": (
str(amenity.thursday_start_hour)
if amenity.thursday_start_hour
else None
),
"thursday_end_hour": (
str(amenity.thursday_end_hour)
if amenity.thursday_end_hour
else None
),
"friday_start_hour": (
str(amenity.friday_start_hour)
if amenity.friday_start_hour
else None
),
"friday_end_hour": (
str(amenity.friday_end_hour)
if amenity.friday_end_hour
else None
),
"saturday_start_hour": (
str(amenity.saturday_start_hour)
if amenity.saturday_start_hour
else None
),
"saturday_end_hour": (
str(amenity.saturday_end_hour)
if amenity.saturday_end_hour
else None
),
"content": amenity.content,
"embedding": amenity.embedding,
# Vector type does not support None value
"embedding": Vector(tuple(amenity.embedding or [])),
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
}
)
)
Expand All @@ -122,8 +259,12 @@ async def delete_collections(collection_list: list[AsyncCollectionReference]):
"flight_number": flight.flight_number,
"departure_airport": flight.departure_airport,
"arrival_airport": flight.arrival_airport,
"departure_time": flight.departure_time,
"arrival_time": flight.arrival_time,
"departure_time": flight.departure_time.strftime(
"%Y-%m-%d %H:%M:%S"
),
"arrival_time": flight.arrival_time.strftime(
"%Y-%m-%d %H:%M:%S"
),
"departure_gate": flight.departure_gate,
"arrival_gate": flight.arrival_gate,
}
Expand All @@ -142,12 +283,49 @@ async def delete_collections(collection_list: list[AsyncCollectionReference]):
.set(
{
"content": policy.content,
"embedding": policy.embedding,
# Vector type does not accept None value
"embedding": Vector(policy.embedding or []),
}
)
)
await asyncio.gather(*create_policies_tasks)

# Initialize single-field vector indexes
create_amenities_vector_index = [
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
"gcloud",
"alpha",
"firestore",
"indexes",
"composite",
"create",
"--collection-group=amenities",
"--query-scope=COLLECTION",
'--field-config=field-path=embedding,vector-config={"dimension":768,"flat":"{}"}',
"--database=(default)",
]
create_amenities_process = await asyncio.create_subprocess_exec(
*create_amenities_vector_index,
)
await create_amenities_process.wait()

create_policies_vector_index = [
"gcloud",
"alpha",
"firestore",
"indexes",
"composite",
"create",
"--collection-group=policies",
"--query-scope=COLLECTION",
'--field-config=field-path=embedding,vector-config={"dimension":768,"flat":"{}"}',
"--database=(default)",
]

create_policies_process = await asyncio.create_subprocess_exec(
*create_policies_vector_index,
)
await create_policies_process.wait()

async def export_data(
self,
) -> tuple[
Expand All @@ -171,6 +349,7 @@ async def export_data(
async for doc in amenities_docs:
amenity_dict = doc.to_dict()
amenity_dict["id"] = doc.id
amenity_dict["embedding"] = list(amenity_dict["embedding"])
amenities.append(models.Amenity.model_validate(amenity_dict))

flights = []
Expand All @@ -183,7 +362,9 @@ async def export_data(
async for doc in policies_docs:
policy_dict = doc.to_dict()
policy_dict["id"] = doc.id
policy_dict["embedding"] = list(policy_dict["embedding"])
policies.append(models.Policy.model_validate(policy_dict))

return airports, amenities, flights, policies

async def get_airport_by_id(self, id: int) -> Optional[models.Airport]:
Expand Down Expand Up @@ -234,12 +415,38 @@ async def get_amenity(self, id: int) -> Optional[models.Amenity]:
)
amenity_doc = await query.get()
amenity_dict = amenity_doc.to_dict() | {"id": amenity_doc.id}
amenity_dict["embedding"] = list(amenity_dict["embedding"])
return models.Amenity.model_validate(amenity_dict)
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[Any]:
raise NotImplementedError("Semantic search not yet supported in Firestore.")
collection = AsyncQuery(self.__client.collection("amenities"))
query_vector = Vector(query_embedding)
# Using the same similarity metric to the embedding model's training method
# produce the most accurate result
distance_measure = DistanceMeasure.DOT_PRODUCT
query = collection.find_nearest(
vector_field="embedding",
query_vector=query_vector,
distance_measure=distance_measure,
limit=top_k,
)

docs = query.stream()
amenities = []
async for doc in docs:
amenity_dict = {
"id": doc.id,
"category": doc.get("category"),
"description": doc.get("description"),
"hour": doc.get("hour"),
"location": doc.get("location"),
"name": doc.get("name"),
"terminal": doc.get("terminal"),
}
amenities.append(amenity_dict)
return amenities

async def get_flight(self, flight_id: int) -> Optional[models.Flight]:
query = self.__client.collection("flights").where(
Expand Down Expand Up @@ -326,8 +533,23 @@ async def list_tickets(

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[str]:
raise NotImplementedError("Semantic search not yet supported in Firestore.")
) -> list[Any]:
collection = AsyncQuery(self.__client.collection("policies"))
query_vector = Vector(query_embedding)
distance_measure = DistanceMeasure.DOT_PRODUCT
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
query = collection.find_nearest(
vector_field="embedding",
query_vector=query_vector,
distance_measure=distance_measure,
limit=top_k,
)
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved

docs = query.stream()
policies = []
async for doc in docs:
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
policy_dict = {"id": doc.id, "content": doc.get("content")}
policies.append(policy_dict)
return policies

async def close(self):
self.__client.close()
2 changes: 1 addition & 1 deletion retrieval_service/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
asyncpg==0.29.0
fastapi==0.109.2
google-auth==2.32.0
google-cloud-firestore==2.14.0
google-cloud-firestore @ git+https://github.com/googleapis/python-firestore.git@2de16209409c9d9ba41d3444400e6a39ee1b2936
ferdeleong marked this conversation as resolved.
Show resolved Hide resolved
google-cloud-aiplatform==1.56.0
google-cloud-spanner==3.46.0
langchain-core==0.2.8
Expand Down