Clean up pypi_scout.api.main #10

Merged
merged 1 commit on Jun 23, 2024
1 change: 1 addition & 0 deletions frontend/app/utils/search.ts
@@ -30,6 +30,7 @@ export const handleSearch = async (
      `${apiUrl}/search`,
      {
        query: query,
+       top_k: 40,
      },
      {
        headers: {
17 changes: 15 additions & 2 deletions pypi_scout/api/data_loader.py
@@ -19,6 +19,7 @@ def load_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
        else:
            raise ValueError(f"Unexpected value found for STORAGE_BACKEND: {self.config.STORAGE_BACKEND}")  # noqa: TRY003

+       df_embeddings = self._drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings, df_packages)
        return df_packages, df_embeddings

    def _load_local_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
@@ -56,10 +57,22 @@ def _load_blob_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:

        return df_packages, df_embeddings

-    def _log_packages_dataset_info(self, df_packages: pl.DataFrame) -> None:
+    @staticmethod
+    def _log_packages_dataset_info(df_packages: pl.DataFrame) -> None:
        logging.info(f"Finished loading the `packages` dataset. Number of rows in dataset: {len(df_packages):,}")
        logging.info(df_packages.describe())

-    def _log_embeddings_dataset_info(self, df_embeddings: pl.DataFrame) -> None:
+    @staticmethod
+    def _log_embeddings_dataset_info(df_embeddings: pl.DataFrame) -> None:
        logging.info(f"Finished loading the `embeddings` dataset. Number of rows in dataset: {len(df_embeddings):,}")
        logging.info(df_embeddings.describe())
+
+    @staticmethod
+    def _drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings, df_packages):
+        # Keep only the rows in the embeddings dataset whose package also occurs in the packages dataset.
+        # In theory this should never drop anything, but it is a useful fail-safe against inconsistencies in the API.
+        logging.info("Dropping packages in the `embeddings` dataset that do not occur in the `packages` dataset...")
+        logging.info(f"Number of rows before dropping: {len(df_embeddings):,}")
+        df_embeddings = df_embeddings.join(df_packages, on="name", how="semi")
+        logging.info(f"Number of rows after dropping: {len(df_embeddings):,}")
+        return df_embeddings
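The fail-safe added above is a Polars semi join: it keeps only the rows of `df_embeddings` whose `name` also appears in `df_packages`, and adds no columns from the right-hand frame. A minimal sketch of that behavior, using invented package names and toy frames rather than the real datasets:

import polars as pl

# Toy stand-ins for the two datasets loaded by ApiDataLoader (all values are made up).
df_packages = pl.DataFrame({"name": ["numpy", "pandas"], "weekly_downloads": [100_000, 90_000]})
df_embeddings = pl.DataFrame({"name": ["numpy", "pandas", "ghost-package"], "embedding": [[0.1], [0.2], [0.3]]})

# Semi join: keep left-hand rows whose "name" occurs on the right; no right-hand columns are added.
df_embeddings = df_embeddings.join(df_packages, on="name", how="semi")
print(df_embeddings["name"].to_list())  # ['numpy', 'pandas'], so 'ghost-package' is dropped

Because the join adds no columns, the embeddings schema stays unchanged, which is why a semi join fits here better than an inner join followed by a column selection.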
44 changes: 4 additions & 40 deletions pypi_scout/api/main.py
@@ -1,40 +1,35 @@
import logging

-import polars as pl
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.requests import Request

from pypi_scout.api.data_loader import ApiDataLoader
+from pypi_scout.api.models import QueryModel, SearchResponse
from pypi_scout.config import Config
from pypi_scout.embeddings.simple_vector_database import SimpleVectorDatabase
from pypi_scout.utils.logging import setup_logging
from pypi_scout.utils.score_calculator import calculate_score

-# Setup logging
setup_logging()
logging.info("Initializing backend...")

-# Initialize limiter
limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

-# Load environment variables and configuration
load_dotenv()
config = Config()

-# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
-    allow_origins=["*"],  # Temporary wildcard for testing
+    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
@@ -44,28 +39,9 @@
df_packages, df_embeddings = data_loader.load_dataset()

model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME)

vector_database = SimpleVectorDatabase(embeddings_model=model, df_embeddings=df_embeddings)


-class QueryModel(BaseModel):
-    query: str
-    top_k: int = config.N_RESULTS_TO_RETURN
-
-
-class Match(BaseModel):
-    name: str
-    summary: str
-    similarity: float
-    weekly_downloads: int
-
-
-class SearchResponse(BaseModel):
-    matches: list[Match]
-    warning: bool = False
-    warning_message: str = None


@app.post("/api/search", response_model=SearchResponse)
@limiter.limit("4/minute")
async def search(query: QueryModel, request: Request):
@@ -75,7 +51,7 @@ async def search(query: QueryModel, request: Request):
    The top_k packages with the highest score are returned.
    """

-    if query.top_k > 60:
+    if query.top_k > 100:
        raise HTTPException(status_code=400, detail="top_k cannot be larger than 100.")

    logging.info(f"Searching for similar projects. Query: '{query.query}'")
@@ -85,18 +61,6 @@
        f"Fetched the {len(df_matches)} most similar projects. Calculating the weighted scores and filtering..."
    )

-    warning = False
-    warning_message = ""
-    matches_missing_in_local_dataset = df_matches.filter(pl.col("weekly_downloads").is_null())["name"].to_list()
-    if matches_missing_in_local_dataset:
-        warning = True
-        warning_message = (
-            f"The following entries have 'None' for 'weekly_downloads': {matches_missing_in_local_dataset}. "
-            "These entries were found in the vector database but not in the local dataset and have been excluded from the results."
-        )
-        logging.error(warning_message)
-        df_matches = df_matches.filter(~pl.col("name").is_in(matches_missing_in_local_dataset))
-
    df_matches = calculate_score(
        df_matches, weight_similarity=config.WEIGHT_SIMILARITY, weight_weekly_downloads=config.WEIGHT_WEEKLY_DOWNLOADS
    )
@@ -107,4 +71,4 @@

    logging.info(f"Returning the {len(df_matches)} best matches.")
    df_matches = df_matches.select(["name", "similarity", "summary", "weekly_downloads"])
-    return SearchResponse(matches=df_matches.to_dicts(), warning=warning, warning_message=warning_message)
+    return SearchResponse(matches=df_matches.to_dicts())
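With the inline models and the warning plumbing removed, the endpoint's contract is easy to exercise end to end. A request sketch against a local instance (the base URL and example query are assumptions; the route, the top_k cap of 100, and the 4-requests-per-minute limit come from the code above):

import requests  # any HTTP client works; requests is assumed to be installed

BASE_URL = "http://localhost:8000"  # hypothetical local deployment; adjust to your setup

response = requests.post(
    f"{BASE_URL}/api/search",
    json={"query": "plotting library", "top_k": 40},  # top_k > 100 is rejected with HTTP 400
    timeout=30,
)
response.raise_for_status()
for match in response.json()["matches"]:
    print(match["name"], match["similarity"], match["weekly_downloads"])

Note that the limiter allows only 4 calls per minute per client IP, so repeated manual testing will start returning HTTP 429 responses.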
19 changes: 19 additions & 0 deletions pypi_scout/api/models.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+
+class QueryModel(BaseModel):
+    query: str
+    top_k: int
+
+
+class Match(BaseModel):
+    name: str
+    summary: str
+    similarity: float
+    weekly_downloads: int
+
+
+class SearchResponse(BaseModel):
+    matches: list[Match]
+    warning: bool = False
+    warning_message: str = None
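These are plain Pydantic models, so the dictionaries produced by `df_matches.to_dicts()` in main.py are validated and coerced into `Match` instances when the response is built. A small sketch with invented data:

from pypi_scout.api.models import SearchResponse

# Pydantic validates each dict in the list against the Match schema.
response = SearchResponse(
    matches=[{"name": "numpy", "summary": "Array computing.", "similarity": 0.91, "weekly_downloads": 100_000}]
)
print(response.matches[0].name)  # "numpy"
print(response.warning)          # False, the default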
3 changes: 0 additions & 3 deletions pypi_scout/config.py
@@ -38,9 +38,6 @@ class Config:
    # Google Drive file ID for downloading the raw dataset.
    GOOGLE_FILE_ID = "1IDJvCsq1gz0yUSXgff13pMl3nUk7zJzb"

-    # Number of top results to return for a query.
-    N_RESULTS_TO_RETURN = 40
-
    # Fraction of the dataset to include in the vector database. This value determines the portion of top packages
    # (sorted by weekly downloads) to include. Increase this value to include a larger portion of the dataset, up to 1.0 (100%).
    # For reference, a value of 0.25 corresponds to including all PyPI packages with at least approximately 650 weekly downloads