Skip to content

Commit

Permalink
feat: update airport dataset and add search endpoint (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 authored Oct 24, 2023
1 parent d36fd29 commit baabfe4
Show file tree
Hide file tree
Showing 8 changed files with 1,655 additions and 7,727 deletions.
9,212 changes: 1,513 additions & 7,699 deletions data/airport_dataset.csv

Large diffs are not rendered by default.

23 changes: 19 additions & 4 deletions extension_service/app/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def test_hello_world(app):
assert response.json() == {"message": "Hello World"}


def test_get_amenity(app):
def test_get_airport(app):
with TestClient(app) as client:
response = client.get(
"/amenities",
"/airports",
params={
"id": 1,
},
Expand All @@ -63,10 +63,25 @@ def test_get_amenity(app):
assert output[0]


def test_get_airport(app):
def test_airports_semantic_lookup(app):
with TestClient(app) as client:
response = client.get(
"/airports",
"/airports/semantic_lookup",
params={
"query": "What is the airport in san francisco.",
"top_k": 5,
},
)
assert response.status_code == 200
output = response.json()
assert len(output) == 5
assert output[0]


def test_get_amenity(app):
with TestClient(app) as client:
response = client.get(
"/amenities",
params={
"id": 1,
},
Expand Down
24 changes: 18 additions & 6 deletions extension_service/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@ async def root():
return {"message": "Hello World"}


@routes.get("/airports")
async def get_airport(id: int, request: Request):
ds: datastore.Client = request.app.state.datastore
results = await ds.get_airport(id)
return results


@routes.get("/airports/semantic_lookup")
async def airports_semantic_lookup(query: str, top_k: int, request: Request):
ds: datastore.Client = request.app.state.datastore

embed_service: Embeddings = request.app.state.embed_service
query_embedding = embed_service.embed_query(query)

results = await ds.airports_semantic_lookup(query_embedding, 0.7, top_k)
return results


@routes.get("/amenities")
async def get_amenity(id: int, request: Request):
ds: datastore.Client = request.app.state.datastore
Expand All @@ -41,10 +59,4 @@ async def amenities_search(query: str, top_k: int, request: Request):
query_embedding = embed_service.embed_query(query)

results = await ds.amenities_search(query_embedding, 0.7, top_k)


@routes.get("/airports")
async def get_airport(id: int, request: Request):
ds: datastore.Client = request.app.state.datastore
results = await ds.get_airport(id)
return results
12 changes: 9 additions & 3 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import Any, Dict, Generic, Optional, TypeVar

import models

Expand Down Expand Up @@ -49,20 +49,26 @@ async def initialize_data(
self,
airports: list[models.Airport],
amenities: list[models.Amenity],
flights: List[models.Flight],
flights: list[models.Flight],
) -> None:
pass

@abstractmethod
async def export_data(
self,
) -> tuple[list[models.Airport], list[models.Amenity], List[models.Flight]]:
) -> tuple[list[models.Airport], list[models.Amenity], list[models.Flight]]:
pass

@abstractmethod
async def get_airport(self, id: int) -> Optional[models.Airport]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def airports_semantic_lookup(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> Optional[list[models.Airport]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")
Expand Down
45 changes: 38 additions & 7 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import asyncio
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, Literal, Optional

import asyncpg
from pgvector.asyncpg import register_vector
Expand Down Expand Up @@ -67,7 +67,7 @@ async def initialize_data(
self,
airports: list[models.Airport],
amenities: list[models.Amenity],
flights: List[models.Flight],
flights: list[models.Flight],
) -> None:
async with self.__pool.acquire() as conn:
# If the table already exists, drop it to avoid conflicts
Expand Down Expand Up @@ -119,14 +119,19 @@ async def initialize_data(
iata TEXT,
name TEXT,
city TEXT,
country TEXT
country TEXT,
content TEXT NOT NULL,
embedding vector(768) NOT NULL
)
"""
)
# Insert all the data
await conn.executemany(
"""INSERT INTO airports VALUES ($1, $2, $3, $4, $5)""",
[(a.id, a.iata, a.name, a.city, a.country) for a in airports],
"""INSERT INTO airports VALUES ($1, $2, $3, $4, $5, $6, $7)""",
[
(a.id, a.iata, a.name, a.city, a.country, a.content, a.embedding)
for a in airports
],
)

# If the table already exists, drop it to avoid conflicts
Expand All @@ -142,8 +147,8 @@ async def initialize_data(
terminal TEXT,
category TEXT,
hour TEXT,
content TEXT,
embedding vector(768)
content TEXT NOT NULL,
embedding vector(768) NOT NULL
)
"""
)
Expand Down Expand Up @@ -198,6 +203,32 @@ async def get_airport(self, id: int) -> Optional[models.Airport]:
result = models.Airport.model_validate(dict(result))
return result

async def airports_semantic_lookup(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> Optional[list[models.Airport]]:
results = await self.__pool.fetch(
"""
SELECT id, iata, name, city, country
FROM (
SELECT id, iata, name, city, country, 1 - (embedding <=> $1) AS similarity
FROM airports
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_airports
""",
query_embedding,
similarity_threshold,
top_k,
timeout=10,
)

if results is []:
return None

results = [models.Airport.model_validate(dict(r)) for r in results]
return results

async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
Expand Down
51 changes: 45 additions & 6 deletions extension_service/datastore/providers/postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ class MockAsyncpgPool(asyncpg.Pool):
def __init__(self, mocks: Dict[str, MockRecord]):
self.mocks = mocks

async def fetch(self, query, *args):
return self.mocks.get(query.strip())
async def fetch(self, query, *args, timeout=None):
query = " ".join(q.strip() for q in query.splitlines()).strip()
return self.mocks.get(query)

async def fetchrow(self, query, *args, timeout=None):
query = " ".join(q.strip() for q in query.splitlines()).strip()
return self.mocks.get(query)


async def mock_postgres_provider(mocks: Dict[str, MockRecord]) -> postgres.Client:
Expand All @@ -51,6 +56,31 @@ async def mock_postgres_provider(mocks: Dict[str, MockRecord]) -> postgres.Clien

@pytest.mark.asyncio
async def test_get_airport():
mockRecord = MockRecord(
[
("id", 1),
("iata", "FOO"),
("name", "Foo Bar"),
("city", "baz"),
("country", "bundy"),
]
)
query = "SELECT id, iata, name, city, country FROM airports WHERE id=$1"
mocks = {query: mockRecord}
mockCl = await mock_postgres_provider(mocks)
res = await mockCl.get_airport(1)
expected_res = models.Airport(
id=1,
iata="FOO",
name="Foo Bar",
city="baz",
country="bundy",
)
assert res == expected_res


@pytest.mark.asyncio
async def test_airports_semantic_lookup():
mockRecord = [
MockRecord(
[
Expand All @@ -62,11 +92,20 @@ async def test_get_airport():
]
)
]
mocks = {
"SELECT id, iata, name, city, country FROM airports WHERE id=$1": mockRecord
}
query = """
SELECT id, iata, name, city, country
FROM (
SELECT id, iata, name, city, country, 1 - (embedding <=> $1) AS similarity
FROM airports
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_airports
"""
query = " ".join(q.strip() for q in query.splitlines()).strip()
mocks = {query: mockRecord}
mockCl = await mock_postgres_provider(mocks)
res = await mockCl.get_airport(1)
res = await mockCl.airports_semantic_lookup(1, 0.7, 1)
expected_res = [
models.Airport(
id=1,
Expand Down
13 changes: 12 additions & 1 deletion extension_service/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,29 @@
import ast
import datetime
from decimal import Decimal
from typing import List
from typing import Optional

from numpy import float32
from pydantic import BaseModel, ConfigDict, FieldValidationInfo, field_validator


class Airport(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

id: int
iata: str
name: str
city: str
country: str
content: Optional[str] = None
embedding: Optional[list[float32]] = None

@field_validator("embedding", mode="before")
def validate(cls, v):
if type(v) == str:
v = ast.literal_eval(v)
v = [float32(f) for f in v]
return v


class Amenity(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion extension_service/run_database_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def main():
await ds.close()

with open("../data/airport_dataset.csv.new", "w") as f:
col_names = ["id", "iata", "name", "city", "country"]
col_names = ["id", "iata", "name", "city", "country", "content", "embedding"]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for a in airports:
Expand Down

0 comments on commit baabfe4

Please sign in to comment.