Skip to content

Commit

Permalink
more improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
florian committed Jun 15, 2024
1 parent f36ea6e commit 3ec5cc2
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 48 deletions.
4 changes: 2 additions & 2 deletions frontend/app/layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import "./globals.css";
const inter = Inter({ subsets: ["latin"] });

// Page-level metadata consumed by the Next.js App Router for <head> generation.
// NOTE: the diff residue had duplicate `title`/`description` keys (old + new
// lines both present), which is an invalid object literal; only the new values
// are kept here.
export const metadata: Metadata = {
  title: "PyPi LLM Search",
  description: "Find PyPi packages with natural language using LLM's",
};

export default function RootLayout({
Expand Down
49 changes: 36 additions & 13 deletions frontend/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export default function Home() {
const [sortDirection, setSortDirection] = useState("desc");
const [loading, setLoading] = useState(false);
const [error, setError] = useState("");
const [infoBoxVisible, setInfoBoxVisible] = useState(false);

const handleSearch = async () => {
setLoading(true);
Expand All @@ -28,10 +29,8 @@ export default function Home() {
},
},
);
const sortedResults = response.data.matches.sort(
(a, b) => b.weekly_downloads - a.weekly_downloads,
);
setResults(sortedResults);
const fetchedResults = response.data.matches;
setResults(sortResults(fetchedResults, sortField, sortDirection));
} catch (error) {
setError("Error fetching search results.");
console.error("Error fetching search results:", error);
Expand All @@ -40,17 +39,20 @@ export default function Home() {
}
};

const sortResults = (field) => {
const direction =
sortField === field && sortDirection === "asc" ? "desc" : "asc";
const sorted = [...results].sort((a, b) => {
// Pure sort helper: returns a NEW array ordered by `field` in `direction`
// ("asc" | "desc") without mutating the input. The stray unreachable
// `setResults(sorted);` line (diff residue referencing an undefined `sorted`)
// has been removed — state updates are the caller's job (see handleSort).
const sortResults = (data, field, direction) => {
  return [...data].sort((a, b) => {
    if (a[field] < b[field]) return direction === "asc" ? -1 : 1;
    if (a[field] > b[field]) return direction === "asc" ? 1 : -1;
    return 0; // equal values keep their relative order (Array.sort is stable)
  });
};

const handleSort = (field) => {
const direction =
sortField === field && sortDirection === "asc" ? "desc" : "asc";
setSortField(field);
setSortDirection(direction);
setResults(sortResults(results, field, direction));
};

return (
Expand Down Expand Up @@ -80,17 +82,38 @@ export default function Home() {
{error && <p className="text-red-500">{error}</p>}
</div>

<div className="w-full flex justify-center mt-6">
<button
className="w-[250px] p-2 border rounded bg-gray-300 hover:bg-gray-400 focus:outline-none focus:ring-2 focus:ring-gray-500"
onClick={() => setInfoBoxVisible(!infoBoxVisible)}
>
{infoBoxVisible ? "Hide Info" : "How does this work?"}
</button>
</div>

{infoBoxVisible && (
<div className="w-3/5 bg-white p-6 rounded-lg shadow-lg mt-4">
<h2 className="text-2xl font-bold mb-2">How does this work?</h2>
<p className="text-gray-700">
This application allows you to search for Python packages on PyPi
using natural language. So an example query would be "a package that
creates plots and beautiful visualizations". Once you click search,
your query will be matched against the summary and the first part of
the description of all PyPi packages with more than 50 weekly
downloads, and the 50 most similar results will be displayed in a
table below.
</p>
</div>
)}

{results.length > 0 && (
<div className="w-full flex justify-center mt-6">
<div className="w-11/12 bg-white p-6 rounded-lg shadow-lg flex flex-col items-center">
<p className="mb-4 text-gray-700">
Displaying the {results.length} most similar results:
</p>
<SearchResultsTable
results={results}
sortField={sortField}
sortDirection={sortDirection}
onSort={sortResults}
onSort={handleSort}
/>
</div>
</div>
Expand Down
101 changes: 78 additions & 23 deletions notebooks/main.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,23 @@
"from pypi_llm.config import Config\n",
"from pypi_llm.data.description_cleaner import DescriptionCleaner, CLEANING_FAILED\n",
"from pypi_llm.data.reader import DataReader\n",
"from sentence_transformers import SentenceTransformer\n",
"from pypi_llm.vector_database import VectorDatabaseInterface\n",
"\n",
"load_dotenv()\n",
"config = Config()\n",
"\n",
"df = DataReader(config.DATA_DIR).read()\n",
"df = DescriptionCleaner().clean(df, \"description\", \"description_cleaned\")\n",
"df = df.filter(~pl.col(\"description_cleaned\").is_null())\n",
"df = df.filter(pl.col(\"description_cleaned\")!=CLEANING_FAILED)"
"# Load dataset and model\n",
"df = pl.read_csv(config.DATA_DIR / config.PROCESSED_DATASET_CSV_NAME)\n",
"model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME)\n",
"\n",
"# Initialize vector database interface\n",
"vector_database_interface = VectorDatabaseInterface(\n",
" pinecone_token=config.PINECONE_TOKEN,\n",
" pinecone_index_name=config.PINECONE_INDEX_NAME,\n",
" embeddings_model=model,\n",
" pinecone_namespace=config.PINECONE_NAMESPACE,\n",
")"
]
},
{
Expand All @@ -51,39 +60,85 @@
"metadata": {},
"outputs": [],
"source": [
"with pl.Config(fmt_str_lengths=1000):\n",
"with pl.Config(fmt_str_lengths=100):\n",
" display(df.head(10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "053c9cf1-9f79-4b98-bcc9-85b6b676da84",
"id": "bf393f0c-92c6-4d4a-bd97-d3ea7ebf2b80",
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"model = SentenceTransformer(config.EMBEDDINGS_MODEL)\n",
"embeddings = model.encode(query)\n",
"\n",
"from pinecone import Pinecone, Index\n",
"pc = Pinecone(api_key=config.PINECONE_TOKEN)\n",
"index = pc.Index(config.PINECONE_INDEX_NAME)\n",
"\n",
"matches = index.query(\n",
" namespace=\"ns1\",\n",
" vector=embeddings.tolist(),\n",
" top_k=50,\n",
" include_values=False\n",
"query = \"find unused packages\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07ebd4fd-a0b9-4958-8325-bdff4be45a66",
"metadata": {},
"outputs": [],
"source": [
"df_matches = vector_database_interface.find_similar(query, top_k=100)\n",
"df_matches = df_matches.join(df, how=\"left\", on=\"name\")\n",
"df_matches = df_matches.sort(\"similarity\", descending=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa071203-a3cd-4e80-a7b7-0ac7562bef8d",
"metadata": {},
"outputs": [],
"source": [
"df_matches"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b7b28e7-495c-44db-a939-dfa3e2c45159",
"metadata": {},
"outputs": [],
"source": [
"# Rank the columns\n",
"df_matches = df_matches.with_columns(\n",
" rank_similarity=pl.col(\"similarity\").rank(\"dense\", descending=False),\n",
" rank_weekly_downloads=pl.col(\"weekly_downloads\").rank(\"dense\", descending=False)\n",
")\n",
"\n",
"df_matches = pl.from_dicts([{'name' : x['id'], 'similarity': x['score']} for x in matches['matches']])\n",
"df_matches = df_matches.with_columns(\n",
" normalized_similarity=(pl.col(\"rank_similarity\") - 1) / (df_matches['rank_similarity'].max() - 1),\n",
" normalized_weekly_downloads=(pl.col(\"rank_weekly_downloads\") - 1) / (df_matches['rank_weekly_downloads'].max() - 1)\n",
")\n",
"\n",
"df_matches = df_matches.join(df, how = 'left', on = 'name')\n",
"df_matches = df_matches.with_columns(\n",
" score=0.5 * pl.col(\"normalized_similarity\") + 0.5 * pl.col(\"normalized_weekly_downloads\")\n",
")\n",
"\n",
"df_matches.sort('weekly_downloads', descending=True)\n",
"\n"
"# Sort the DataFrame by the combined score in descending order\n",
"df_matches = df_matches.sort(\"score\", descending=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5465cec-c717-4fc5-aa55-c4c7dc9e79cf",
"metadata": {},
"outputs": [],
"source": [
"df_matches.sort(\"score\", descending=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4384f057-8eaf-431d-a31a-f4f7e203ed35",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
18 changes: 9 additions & 9 deletions pypi_llm/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@
from sentence_transformers import SentenceTransformer

from pypi_llm.config import Config
from pypi_llm.utils.score_calculator import calculate_score
from pypi_llm.vector_database import VectorDatabaseInterface

app = FastAPI()

# Load environment variables
load_dotenv()
config = Config()

# Setup CORS
origins = [
"http://localhost:3000",
# Add other origins if needed
]

app.add_middleware(
Expand All @@ -28,11 +26,9 @@
allow_headers=["*"],
)

# Load dataset and model
df = pl.read_csv(config.DATA_DIR / config.PROCESSED_DATASET_CSV_NAME)
model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME)

# Initialize vector database interface
vector_database_interface = VectorDatabaseInterface(
pinecone_token=config.PINECONE_TOKEN,
pinecone_index_name=config.PINECONE_INDEX_NAME,
Expand All @@ -41,9 +37,9 @@
)


# Define request and response models
class QueryModel(BaseModel):
    # Request body for the /search endpoint.
    # Natural-language description of the package the user is looking for.
    query: str
    # Maximum number of matches to return to the client (default 30).
    top_k: int = 30


class Match(BaseModel):
Expand All @@ -57,10 +53,14 @@ class SearchResponse(BaseModel):
matches: list[Match]


# Define search endpoint
@app.post("/search/", response_model=SearchResponse)
async def search(query: QueryModel):
    """Search PyPI packages matching a natural-language query.

    Fetches 2x ``top_k`` nearest neighbours from the vector database (a wider
    candidate pool for the re-scoring step), joins package metadata, re-ranks
    by the combined similarity/weekly-downloads score, and returns the best
    ``top_k`` matches.

    Args:
        query: Request body with the search text and result count.

    Returns:
        SearchResponse: the top matches as a list of dicts.
    """
    # NOTE(review): the previous version issued a redundant find_similar(top_k=50)
    # call whose result was immediately overwritten — removed here.
    df_matches = vector_database_interface.find_similar(query.query, top_k=query.top_k * 2)
    df_matches = df_matches.join(df, how="left", on="name")

    # Re-rank by the blended score rather than raw embedding similarity.
    df_matches = calculate_score(df_matches)
    df_matches = df_matches.sort("score", descending=True)
    df_matches = df_matches.head(query.top_k)

    # Leftover debug print removed; FastAPI/uvicorn access logs cover this.
    return SearchResponse(matches=df_matches.to_dicts())
2 changes: 1 addition & 1 deletion pypi_llm/scripts/upsert_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@
)

df = df.with_columns(
    # Combine summary and cleaned description into the single text field that
    # gets embedded; " - " separates the two parts for readability.
    # NOTE(review): the diff residue passed this keyword twice in one
    # with_columns() call (old " " and new " - " separators), which is a
    # SyntaxError; only the new form is kept.
    summary_and_description_cleaned=pl.concat_str(pl.col("summary"), pl.lit(" - "), pl.col("description_cleaned"))
)
# Embed each package's combined text and upsert it into the vector database,
# keyed by package name.
vector_database_interface.upsert_polars(df, key_column="name", text_column="summary_and_description_cleaned")
34 changes: 34 additions & 0 deletions pypi_llm/utils/score_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import polars as pl


def calculate_score(df: pl.DataFrame, weight_similarity=0.5, weight_weekly_downloads=0.5) -> pl.DataFrame:
    """
    Calculate a combined score based on similarity and weekly downloads.

    Dense-ranks the 'similarity' and 'weekly_downloads' columns, normalizes
    those ranks to a [0, 1] scale, and computes a weighted sum of the two
    normalized ranks. The DataFrame is sorted by the combined score in
    descending order.

    Args:
        df (pl.DataFrame): DataFrame containing 'similarity' and 'weekly_downloads' columns.
        weight_similarity (float): Weight for the similarity score in the combined score calculation. Default is 0.5.
        weight_weekly_downloads (float): Weight for the weekly downloads score in the combined score calculation. Default is 0.5.

    Returns:
        pl.DataFrame: the input frame with added 'rank_similarity',
        'rank_weekly_downloads', 'normalized_similarity',
        'normalized_weekly_downloads' and 'score' columns, sorted by
        'score' descending.
    """
    df = df.with_columns(
        rank_similarity=pl.col("similarity").rank("dense", descending=False),
        rank_weekly_downloads=pl.col("weekly_downloads").rank("dense", descending=False),
    )

    # Guard the normalization denominators: with a single row, or when every
    # value ties, the max dense rank is 1 and the previous (max - 1) divisor
    # was 0, yielding NaN scores. Clamp to 1 so all-tied columns normalize
    # to 0 instead.
    sim_denominator = max(df["rank_similarity"].max() - 1, 1)
    downloads_denominator = max(df["rank_weekly_downloads"].max() - 1, 1)

    df = df.with_columns(
        normalized_similarity=(pl.col("rank_similarity") - 1) / sim_denominator,
        normalized_weekly_downloads=(pl.col("rank_weekly_downloads") - 1) / downloads_denominator,
    )

    df = df.with_columns(
        score=weight_similarity * pl.col("normalized_similarity")
        + weight_weekly_downloads * pl.col("normalized_weekly_downloads")
    )

    return df.sort("score", descending=True)

0 comments on commit 3ec5cc2

Please sign in to comment.