From c6f1504db63e704cf631a9fc1ca6f03e6dead037 Mon Sep 17 00:00:00 2001
From: Florian Maas
Date: Sun, 23 Jun 2024 20:03:44 +0200
Subject: [PATCH] cleaned up a bit

---
 frontend/app/utils/search.ts  |  1 +
 pypi_scout/api/data_loader.py | 17 ++++++++++++--
 pypi_scout/api/main.py        | 44 ++++------------------------------
 pypi_scout/api/models.py      | 19 +++++++++++++++
 pypi_scout/config.py          |  3 ---
 5 files changed, 39 insertions(+), 45 deletions(-)
 create mode 100644 pypi_scout/api/models.py

diff --git a/frontend/app/utils/search.ts b/frontend/app/utils/search.ts
index 88bfb28..c497acc 100644
--- a/frontend/app/utils/search.ts
+++ b/frontend/app/utils/search.ts
@@ -30,6 +30,7 @@ export const handleSearch = async (
       `${apiUrl}/search`,
       {
         query: query,
+        top_k: 40,
       },
       {
         headers: {
diff --git a/pypi_scout/api/data_loader.py b/pypi_scout/api/data_loader.py
index c4ea031..9eb246a 100644
--- a/pypi_scout/api/data_loader.py
+++ b/pypi_scout/api/data_loader.py
@@ -19,6 +19,7 @@ def load_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
         else:
             raise ValueError(f"Unexpected value found for STORAGE_BACKEND: {self.config.STORAGE_BACKEND}")  # noqa: TRY003

+        df_embeddings = self._drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings, df_packages)
         return df_packages, df_embeddings

     def _load_local_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
@@ -56,10 +57,22 @@ def _load_blob_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:

         return df_packages, df_embeddings

-    def _log_packages_dataset_info(self, df_packages: pl.DataFrame) -> None:
+    @staticmethod
+    def _log_packages_dataset_info(df_packages: pl.DataFrame) -> None:
         logging.info(f"Finished loading the `packages` dataset. Number of rows in dataset: {len(df_packages):,}")
         logging.info(df_packages.describe())

-    def _log_embeddings_dataset_info(self, df_embeddings: pl.DataFrame) -> None:
+    @staticmethod
+    def _log_embeddings_dataset_info(df_embeddings: pl.DataFrame) -> None:
         logging.info(f"Finished loading the `embeddings` dataset. Number of rows in dataset: {len(df_embeddings):,}")
         logging.info(df_embeddings.describe())
+
+    @staticmethod
+    def _drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings: pl.DataFrame, df_packages: pl.DataFrame) -> pl.DataFrame:
+        # Keep only the packages in the embeddings dataset that also occur in the packages dataset.
+        # In theory this should never drop anything, but it is a useful fail-safe against inconsistencies in the API.
+ logging.info("Dropping packages in the `embeddings` dataset that do not occur in the `packages` dataset...") + logging.info(f"Number of rows before dropping: {len(df_embeddings):,}...") + df_embeddings = df_embeddings.join(df_packages, on="name", how="semi") + logging.info(f"Number of rows after dropping: {len(df_embeddings):,}...") + return df_embeddings diff --git a/pypi_scout/api/main.py b/pypi_scout/api/main.py index 5e70afb..f54c585 100644 --- a/pypi_scout/api/main.py +++ b/pypi_scout/api/main.py @@ -1,10 +1,8 @@ import logging -import polars as pl from dotenv import load_dotenv from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel from sentence_transformers import SentenceTransformer from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded @@ -12,29 +10,26 @@ from starlette.requests import Request from pypi_scout.api.data_loader import ApiDataLoader +from pypi_scout.api.models import QueryModel, SearchResponse from pypi_scout.config import Config from pypi_scout.embeddings.simple_vector_database import SimpleVectorDatabase from pypi_scout.utils.logging import setup_logging from pypi_scout.utils.score_calculator import calculate_score -# Setup logging setup_logging() logging.info("Initializing backend...") -# Initialize limiter limiter = Limiter(key_func=get_remote_address) app = FastAPI() app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) -# Load environment variables and configuration load_dotenv() config = Config() -# Add CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Temporary wildcard for testing + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -44,28 +39,9 @@ df_packages, df_embeddings = data_loader.load_dataset() model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME) - vector_database = SimpleVectorDatabase(embeddings_model=model, df_embeddings=df_embeddings) -class QueryModel(BaseModel): - query: str - top_k: int = config.N_RESULTS_TO_RETURN - - -class Match(BaseModel): - name: str - summary: str - similarity: float - weekly_downloads: int - - -class SearchResponse(BaseModel): - matches: list[Match] - warning: bool = False - warning_message: str = None - - @app.post("/api/search", response_model=SearchResponse) @limiter.limit("4/minute") async def search(query: QueryModel, request: Request): @@ -75,7 +51,7 @@ async def search(query: QueryModel, request: Request): The top_k packages with the highest score are returned. """ - if query.top_k > 60: + if query.top_k > 100: raise HTTPException(status_code=400, detail="top_k cannot be larger than 100.") logging.info(f"Searching for similar projects. Query: '{query.query}'") @@ -85,18 +61,6 @@ async def search(query: QueryModel, request: Request): f"Fetched the {len(df_matches)} most similar projects. Calculating the weighted scores and filtering..." ) - warning = False - warning_message = "" - matches_missing_in_local_dataset = df_matches.filter(pl.col("weekly_downloads").is_null())["name"].to_list() - if matches_missing_in_local_dataset: - warning = True - warning_message = ( - f"The following entries have 'None' for 'weekly_downloads': {matches_missing_in_local_dataset}. " - "These entries were found in the vector database but not in the local dataset and have been excluded from the results." 
-        )
-        logging.error(warning_message)
-        df_matches = df_matches.filter(~pl.col("name").is_in(matches_missing_in_local_dataset))
-
     df_matches = calculate_score(
         df_matches, weight_similarity=config.WEIGHT_SIMILARITY, weight_weekly_downloads=config.WEIGHT_WEEKLY_DOWNLOADS
     )
@@ -107,4 +71,4 @@ async def search(query: QueryModel, request: Request):
     logging.info(f"Returning the {len(df_matches)} best matches.")

     df_matches = df_matches.select(["name", "similarity", "summary", "weekly_downloads"])
-    return SearchResponse(matches=df_matches.to_dicts(), warning=warning, warning_message=warning_message)
+    return SearchResponse(matches=df_matches.to_dicts())
diff --git a/pypi_scout/api/models.py b/pypi_scout/api/models.py
new file mode 100644
index 0000000..ed54bbb
--- /dev/null
+++ b/pypi_scout/api/models.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+
+class QueryModel(BaseModel):
+    query: str
+    top_k: int
+
+
+class Match(BaseModel):
+    name: str
+    summary: str
+    similarity: float
+    weekly_downloads: int
+
+
+class SearchResponse(BaseModel):
+    matches: list[Match]
+    warning: bool = False
+    warning_message: str | None = None
diff --git a/pypi_scout/config.py b/pypi_scout/config.py
index 57cb849..64bdcf3 100644
--- a/pypi_scout/config.py
+++ b/pypi_scout/config.py
@@ -38,9 +38,6 @@ class Config:
     # Google Drive file ID for downloading the raw dataset.
     GOOGLE_FILE_ID = "1IDJvCsq1gz0yUSXgff13pMl3nUk7zJzb"

-    # Number of top results to return for a query.
-    N_RESULTS_TO_RETURN = 40
-
     # Fraction of the dataset to include in the vector database. This value determines the portion of top packages
     # (sorted by weekly downloads) to include. Increase this value to include a larger portion of the dataset, up to 1.0 (100%).
     # For reference, a value of 0.25 corresponds to including all PyPI packages with at least approximately 650 weekly downloads
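
Note on the fail-safe introduced in ApiDataLoader: the work is done by a Polars semi-join, which keeps only the rows of the left frame whose join key also occurs in the right frame, without adding any columns from the right frame. A minimal, self-contained sketch of the behavior (the package names and embeddings below are made up for illustration and do not come from the real datasets):

    import polars as pl

    # Hypothetical stand-ins for the real `packages` and `embeddings` datasets.
    df_packages = pl.DataFrame({"name": ["numpy", "pandas"], "weekly_downloads": [100, 80]})
    df_embeddings = pl.DataFrame(
        {"name": ["numpy", "pandas", "ghost-package"], "embedding": [[0.1], [0.2], [0.3]]}
    )

    # A semi-join filters the left frame by membership in the right frame.
    filtered = df_embeddings.join(df_packages, on="name", how="semi")
    print(filtered["name"].to_list())  # ["numpy", "pandas"]; "ghost-package" is dropped

Because any embedding whose package is missing from the packages dataset is removed at load time, it can no longer surface in a query result with a null weekly_downloads, which is what the removed warning/filter block in main.py previously had to handle per request.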
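Since QueryModel no longer defines a default for top_k (the old default came from the removed config.N_RESULTS_TO_RETURN), clients must now send top_k explicitly; the frontend sends 40, and the endpoint rejects values above 100 with a 400. A sketch of a direct call to the backend, assuming for illustration that it runs on localhost:8000 and that the requests library is available:

    import requests  # any HTTP client works; requests is assumed here

    # Route taken from the patch; http://localhost:8000 is an assumed dev address.
    response = requests.post(
        "http://localhost:8000/api/search",
        json={"query": "plot dataframes", "top_k": 40},  # top_k is now required
        timeout=30,
    )
    response.raise_for_status()
    for match in response.json()["matches"]:
        print(match["name"], match["similarity"], match["weekly_downloads"])

Note that the endpoint is rate-limited by slowapi to 4 requests per minute per client address.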