feat: initial pgvector indices #11

Merged 2 commits on Jan 24, 2024
3 changes: 1 addition & 2 deletions .dockerignore
@@ -10,5 +10,4 @@
 **/__pycache__

 # Include (don't ignore) the migrations.
-!migrations/*.sql
-!yoyo.ini
+!migrations/*.sql
4 changes: 3 additions & 1 deletion app/chunks/router.py
@@ -12,7 +12,9 @@


 @router.post("/retrieve")
-async def retrieve_chunks(store: StoreDep, request: RetrieveRequest) -> RetrieveResponse:
+async def retrieve_chunks(
+    store: StoreDep, request: RetrieveRequest
+) -> RetrieveResponse:
     """Retrieve chunks based on a given query."""

     from llama_index.response_synthesizers import ResponseMode
48 changes: 44 additions & 4 deletions app/collections/models.py
@@ -1,8 +1,25 @@
 from enum import Enum

-from pydantic import BaseModel, ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
+
+
+class DistanceMetric(Enum):
+    cosine = "cosine"
+    inner_product = "ip"
+    l2 = "l2"
+
+    def vector_ops(self) -> str:
+        match self:
+            case DistanceMetric.cosine:
+                return "vector_cosine_ops"
+            case DistanceMetric.inner_product:
+                return "vector_ip_ops"
+            case DistanceMetric.l2:
+                return "vector_l2_ops"
+

 class Collection(BaseModel):
-    model_config=ConfigDict(from_attributes=True)
+    model_config = ConfigDict(from_attributes=True)

     """A collection of indexed documents."""
     id: int
@@ -11,11 +28,34 @@ class Collection(BaseModel):
     name: str
     """The name of the collection."""

+    text_embedding_model: str
+    """The name of the embedding model.
+
+    NOTE: Changing embedding models is not currently supported.
+    """
+
+    text_distance_metric: DistanceMetric = DistanceMetric.cosine
+    """The distance metric to use on the text embedding.
+
+    NOTE: Changing distance metrics is not currently supported."""
+

 collection_validator = TypeAdapter(Collection)


 class CollectionCreate(BaseModel):
     """The request to create a collection."""

-    name: str
-    """The name of the collection."""
+    name: str = Field(examples=["my_collection"])
+    """The name of the collection."""
+
+    text_embedding_model: str = Field(examples=["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"])
+    """The name of the embedding model.
+
+    NOTE: Changing embedding models is not currently supported.
+    """
+
+    text_distance_metric: DistanceMetric = DistanceMetric.cosine
+    """The distance metric to use on the text embedding.
+
+    NOTE: Changing distance metrics is not currently supported."""
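For orientation (not part of the diff): `DistanceMetric` travels with each collection and selects the pgvector operator class used when the collection's index is built. A minimal sketch of how the new models fit together, assuming the package layout shown above:

```python
# Sketch assuming app.collections.models matches the diff above.
from app.collections.models import CollectionCreate, DistanceMetric

req = CollectionCreate(
    name="my_collection",
    text_embedding_model="hf:BAAI/bge-small-en",
    text_distance_metric=DistanceMetric.inner_product,
)

# vector_ops() maps each metric onto the pgvector operator class that
# add_collection splices into CREATE INDEX (see app/collections/router.py).
assert req.text_distance_metric.vector_ops() == "vector_ip_ops"
assert DistanceMetric.cosine.vector_ops() == "vector_cosine_ops"
assert DistanceMetric.l2.vector_ops() == "vector_l2_ops"
```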
64 changes: 54 additions & 10 deletions app/collections/router.py
@@ -1,29 +1,66 @@
 from typing import Annotated, List

 from fastapi import APIRouter, Path
-from pydantic import parse_obj_as

-from app.collections.models import Collection, CollectionCreate
 from app.common.db import PgConnectionDep
+from app.collections.models import *

 router = APIRouter(prefix="/collections")


+def get_dimensions(model: str) -> int:
+    # TODO: Consider instantiating the model and applying it to a string
+    # to determine the dimensions. This would make it easier to support
+    # new models.
+    match model:
+        case "openai:text-embedding-ada-002":
+            return 1536
+        case "hf:BAAI/bge-small-en":
+            return 384
+        case _:
+            raise ValueError(f"Unsupported model '{model}'")
+
+
 @router.put("/")
-async def add_collection(conn: PgConnectionDep, collection: CollectionCreate) -> Collection:
+async def add_collection(
+    conn: PgConnectionDep, collection: CollectionCreate
+) -> Collection:
     """Create a collection."""
-    result = await conn.fetchrow("""
-        INSERT INTO collection (name) VALUES ($1)
-        RETURNING id, name
-        """,
-        collection.name)
+    dimensions = get_dimensions(collection.text_embedding_model)
+    async with conn.transaction():
+        result = await conn.fetchrow(
+            """
+            INSERT INTO collection (name, text_embedding_model, text_distance_metric)
+            VALUES ($1, $2, $3)
+            RETURNING id, name, text_embedding_model, text_distance_metric
+            """,
+            collection.name,
+            collection.text_embedding_model,
+            collection.text_distance_metric.value,
+        )
+
+        # Create a separate *partial* index on each collection.
+        # This allows us to define different dimensions (and vector distance) for
+        # each collection.
+        #
+        # https://github.com/pgvector/pgvector?tab=readme-ov-file#can-i-store-vectors-with-different-dimensions-in-the-same-column
+        id = result["id"]
+        vector_ops = collection.text_distance_metric.vector_ops()
+        await conn.execute(
+            f"""
+            CREATE INDEX embedding_collection_{id}_index
+            ON embedding
+            USING hnsw ((embedding::vector({dimensions})) {vector_ops})
+            WHERE collection_id = {id}
+            """
+        )
     return Collection.model_validate(dict(result))


 @router.get("/")
 async def list_collections(conn: PgConnectionDep) -> List[Collection]:
     """List collections."""
-    results = await conn.fetch("SELECT id, name FROM collection")
+    results = await conn.fetch("SELECT id, name, text_embedding_model FROM collection")
     return [Collection.model_validate(dict(result)) for result in results]


@@ -33,5 +70,12 @@ async def list_collections(conn: PgConnectionDep) -> List[Collection]:
 @router.get("/{id}")
 async def get_collection(id: PathCollectionId, conn: PgConnectionDep) -> Collection:
     """Get a specific collection."""
-    result = await conn.fetchrow("SELECT id, name FROM collection WHERE id = $1", id)
+    result = await conn.fetchrow(
+        """
+        SELECT id, name, text_embedding_model
+        FROM collection
+        WHERE id = $1
+        """,
+        id,
+    )
     return Collection.model_validate(dict(result))
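A concrete reading of the index DDL above (illustrative, not part of the diff): for a hypothetical collection with id 1 created with `hf:BAAI/bge-small-en` (384 dimensions) and the default cosine metric, the f-string in `add_collection` expands as shown below. Note that a similarity query has to repeat both the `::vector(384)` cast and the `collection_id` filter for the planner to consider the partial index; the retrieval SQL itself is not shown in this diff.

```python
# Illustrative expansion of the CREATE INDEX f-string for id=1,
# dimensions=384, vector_ops="vector_cosine_ops" (hypothetical values).
ddl = """
CREATE INDEX embedding_collection_1_index
ON embedding
USING hnsw ((embedding::vector(384)) vector_cosine_ops)
WHERE collection_id = 1
"""

# Assumed query shape that can use the partial HNSW index: the cast and
# the WHERE clause must match the index definition, and `<=>` is
# pgvector's cosine-distance operator. The selected column is an
# assumption about the embedding table's schema.
query = """
SELECT id
FROM embedding
WHERE collection_id = 1
ORDER BY embedding::vector(384) <=> $1
LIMIT 10
"""
```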
24 changes: 18 additions & 6 deletions app/common/db.py
@@ -1,11 +1,10 @@
 import contextlib
-from enum import Enum
-from typing import Annotated, AsyncIterator, Optional
-from uuid import UUID
-import asyncpg
+from typing import Annotated, AsyncIterator

+import asyncpg
 from fastapi import Depends, Request


 @contextlib.asynccontextmanager
 async def create_pool(dsn: str) -> AsyncIterator[asyncpg.Pool]:
     """
@@ -16,17 +15,30 @@ async def create_pool(dsn: str) -> AsyncIterator[asyncpg.Pool]:
     the following format:
     `postgres://user:pass@host:port/database?option=value`.
     """
-    pool = await asyncpg.create_pool(dsn)
+
+    async def init_pool(conn: asyncpg.Connection):
+        # Need the extension before we register, so do this here.
+        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
+
+        from pgvector.asyncpg import register_vector
+
+        await register_vector(conn)
+
+    pool = await asyncpg.create_pool(dsn, init=init_pool)
     yield pool
     pool.close()


 def _pg_pool(request: Request) -> asyncpg.Pool:
     return request.state.pg_pool


 PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)]


 async def _pg_connection(pool: PgPoolDep) -> asyncpg.Connection:
     async with pool.acquire() as connection:
         yield connection

-PgConnectionDep = Annotated[asyncpg.Connection, Depends(_pg_connection)]

+PgConnectionDep = Annotated[asyncpg.Connection, Depends(_pg_connection)]
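A sketch of what the new `init` hook buys (not part of the diff): once `register_vector` has run on a connection, vector values round-trip directly through asyncpg, e.g. as numpy arrays. The DSN and the exact columns of the embedding table are assumptions here; `collection_id` and `embedding` are the columns referenced by the index DDL above.

```python
# Hedged usage sketch for create_pool in app/common/db.py. The DSN and
# the embedding table's shape are assumptions for illustration.
import asyncio

import numpy as np

from app.common.db import create_pool


async def main() -> None:
    async with create_pool("postgres://postgres:postgres@localhost:5432/app") as pool:
        async with pool.acquire() as conn:
            # register_vector (run by init_pool on every new connection)
            # lets a numpy array bind straight into a vector-typed column.
            await conn.execute(
                "INSERT INTO embedding (collection_id, embedding) VALUES ($1, $2)",
                1,
                np.zeros(384, dtype=np.float32),
            )


asyncio.run(main())
```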
2 changes: 1 addition & 1 deletion app/config.py
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from fastapi.routing import APIRoute
-from pydantic import RedisDsn, PostgresDsn, ValidationInfo, field_validator
+from pydantic import PostgresDsn, RedisDsn, ValidationInfo, field_validator
 from pydantic_core import Url
 from pydantic_settings import BaseSettings

3 changes: 2 additions & 1 deletion app/documents/models.py
@@ -14,6 +14,7 @@ class IngestState(Enum):
     FAILED = "failed"
     """Document failed to be ingested. See `ingest_errors` for details."""

+
 class Document(BaseModel):
     """Schema for documents in the SQL DB."""

@@ -23,4 +24,4 @@ class Document(BaseModel):
     url: str

     ingest_state: Optional[IngestState] = None
-    ingest_error: Optional[str] = None
\ No newline at end of file
+    ingest_error: Optional[str] = None
50 changes: 33 additions & 17 deletions app/documents/router.py
@@ -1,14 +1,12 @@
 from typing import Annotated, List
-import asyncpg

+import asyncpg
 from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Path, status
 from loguru import logger
-from sqlalchemy import Engine
-from sqlmodel import Session, select
-from app.collections.router import PathCollectionId

+from app.collections.router import PathCollectionId
 from app.common.db import PgConnectionDep, PgPoolDep
-from app.documents.models import *
+from app.documents.models import Document
 from app.ingest.extract import extract
 from app.ingest.extract.source import ExtractSource
 from app.ingest.store import Store, StoreDep
@@ -43,11 +41,15 @@ async def ingest_document(id: int, store: Store, pg_pool: asyncpg.Pool):
     nodes = await store.ingestion_pipeline.arun(documents=documents)
     logger.debug("Done. Inserted {} nodes", len(nodes))

-    await conn.execute("""
+    await conn.execute(
+        """
         UPDATE document
         SET ingest_state = 'ingested', ingest_error = NULL
         WHERE id = $1
-        """, id)
+        """,
+        id,
+    )


 @router.put("/")
 async def add_document(
@@ -61,37 +63,51 @@ async def add_document(

     row = None
     async with pg_pool.acquire() as conn:
-        row = await conn.fetchrow("""
-            INSERT INTO document (collection_id, url, ingest_state) VALUES ($1, $2, 'pending')
+        row = await conn.fetchrow(
+            """
+            INSERT INTO document (collection_id, url, ingest_state)
+            VALUES ($1, $2, 'pending')
             RETURNING id, collection_id, url, ingest_state, ingest_error
-            """, collection_id, url)
+            """,
+            collection_id,
+            url,
+        )

     document = Document.model_validate(dict(row))
     background.add_task(ingest_document, document.id, store, pg_pool)
     return document


 PathDocumentId = Annotated[int, Path(..., description="The document ID.")]


 @router.get("/")
-async def list_documents(collection_id: PathCollectionId, conn: PgConnectionDep) -> List[Document]:
+async def list_documents(
+    collection_id: PathCollectionId, conn: PgConnectionDep
+) -> List[Document]:
     """List documents."""
     # TODO: Test
-    results = await conn.fetch("""
+    results = await conn.fetch(
+        """
         SELECT id, collection_id, url, ingest_state, ingest_error
         FROM document WHERE collection_id = $1
-        """, collection_id)
+        """,
+        collection_id,
+    )
     return [Document.model_validate(dict(result)) for result in results]


 @router.get("/{id}")
 async def get_document(
-    conn: PgConnectionDep,
-    collection_id: PathCollectionId,
-    id: PathDocumentId
+    conn: PgConnectionDep, collection_id: PathCollectionId, id: PathDocumentId
 ) -> Document:
     # TODO: Test / return not found?
     result = await conn.fetchrow(
         """
         SELECT id, collection_id, url, ingest_state, ingest_error
         FROM document WHERE id = $1 AND collection_id = $2
-        """, id, collection_id)
+        """,
+        id,
+        collection_id,
+    )
     return Document.model_validate(dict(result))
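One behavior worth calling out in `add_document` above: `background.add_task` defers `ingest_document` until after the response is sent, so the client immediately gets back a row in the 'pending' state and the status flips to 'ingested' (or records an error) later. A self-contained sketch of that pattern, with hypothetical route and schema details:

```python
# Standalone sketch of the insert-then-ingest-in-background pattern
# (hypothetical app; not the PR's actual routing or schema).
import asyncpg
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()


async def ingest(pool: asyncpg.Pool, doc_id: int) -> None:
    # ... run the ingestion pipeline for doc_id, then record completion ...
    async with pool.acquire() as conn:
        await conn.execute(
            "UPDATE document SET ingest_state = 'ingested' WHERE id = $1", doc_id
        )


@app.put("/documents/")
async def add_document(url: str, background: BackgroundTasks) -> dict:
    pool: asyncpg.Pool = app.state.pg_pool  # assumed to be set at startup
    row = await pool.fetchrow(
        "INSERT INTO document (url, ingest_state) VALUES ($1, 'pending') RETURNING id",
        url,
    )
    # Runs only after the response has been sent; the client polls for status.
    background.add_task(ingest, pool, row["id"])
    return {"id": row["id"], "ingest_state": "pending"}
```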
23 changes: 5 additions & 18 deletions app/main.py
@@ -1,10 +1,8 @@
 import contextlib
 from typing import AsyncIterator, TypedDict

+import asyncpg
 from fastapi import FastAPI
-from sqlalchemy import Engine
-from sqlmodel import SQLModel, create_engine
-from loguru import logger

 from app.common import db
 from app.config import app_configs, settings
@@ -14,33 +12,22 @@

 class State(TypedDict):
     store: Store
-    db: Engine
+    pg_pool: asyncpg.Pool


 @contextlib.asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
     """Function creating instances used during the lifespan of the service."""

-    # if settings.APPLY_MIGRATIONS:
-    #     from yoyo import get_backend, read_migrations
-    #     backend = get_backend(settings.DB.unicode_string())
-    #     migrations = read_migrations('migrations')
-    #     with backend.lock():
-    #         outstanding = backend.to_apply(migrations)
-
-    #         logger.info("Applying {} migrations", len(outstanding))
-
-    #         # Apply any outstanding migrations
-    #         backend.apply_migrations(outstanding)
-
-    #         logger.info("Done applying migrations.")
-
+    # TODO: Look at https://gist.github.com/mattbillenstein/270a4d44cbdcb181ac2ed58526ae137d
+    # for simple migration scripts.
     async with db.create_pool(settings.DB.unicode_string()) as pg_pool:
+        if settings.APPLY_MIGRATIONS:
+            async with pg_pool.acquire() as conn:
+                with open("migrations/0001_schema.sql") as schema_file:
+                    schema = schema_file.read()
+                    await conn.execute(schema)
+
         state = {
             "store": Store(),
             "pg_pool": pg_pool,