feat: initial pgvector indices #11

Merged · 2 commits · Jan 24, 2024
Changes from 1 commit
3 changes: 1 addition & 2 deletions .dockerignore
@@ -10,5 +10,4 @@
**/__pycache__

# Include (don't ignore) the migrations.
!migrations/*.sql
!yoyo.ini
!migrations/*.sql
4 changes: 3 additions & 1 deletion app/chunks/router.py
@@ -12,7 +12,9 @@


@router.post("/retrieve")
async def retrieve_chunks(store: StoreDep, request: RetrieveRequest) -> RetrieveResponse:
async def retrieve_chunks(
store: StoreDep, request: RetrieveRequest
) -> RetrieveResponse:
"""Retrieve chunks based on a given query."""

from llama_index.response_synthesizers import ResponseMode
18 changes: 16 additions & 2 deletions app/collections/models.py
@@ -1,8 +1,21 @@
from dataclasses import dataclass
from enum import Enum

from pydantic import BaseModel, ConfigDict, TypeAdapter


@dataclass
class EmbeddingDataMixin:
dimensions: int


class EmbeddingModel(EmbeddingDataMixin, Enum):
openai_text_embedding_ada_002 = 1536
hf_baai_bge_small_en = 384


class Collection(BaseModel):
model_config=ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

"""A collection of indexed documents."""
id: int
@@ -11,11 +24,12 @@ class Collection(BaseModel):
name: str
"""The name of the collection."""


collection_validator = TypeAdapter(Collection)


class CollectionCreate(BaseModel):
"""The request to create a collection."""

name: str
"""The name of the collection."""
"""The name of the collection."""
13 changes: 8 additions & 5 deletions app/collections/router.py
@@ -1,22 +1,25 @@
from typing import Annotated, List

from fastapi import APIRouter, Path
from pydantic import parse_obj_as

from app.collections.models import Collection, CollectionCreate
from app.common.db import PgConnectionDep
from app.collections.models import *

router = APIRouter(prefix="/collections")


@router.put("/")
async def add_collection(conn: PgConnectionDep, collection: CollectionCreate) -> Collection:
async def add_collection(
conn: PgConnectionDep, collection: CollectionCreate
) -> Collection:
"""Create a collection."""
result = await conn.fetchrow("""
result = await conn.fetchrow(
"""
INSERT INTO collection (name) VALUES ($1)
RETURNING id, name
""",
collection.name)
collection.name,
)
return Collection.model_validate(dict(result))


24 changes: 18 additions & 6 deletions app/common/db.py
@@ -1,11 +1,10 @@
import contextlib
from enum import Enum
from typing import Annotated, AsyncIterator, Optional
from uuid import UUID
import asyncpg
from typing import Annotated, AsyncIterator

import asyncpg
from fastapi import Depends, Request


@contextlib.asynccontextmanager
async def create_pool(dsn: str) -> AsyncIterator[asyncpg.Pool]:
"""
@@ -16,17 +15,30 @@ async def create_pool(dsn: str) -> AsyncIterator[asyncpg.Pool]:
the following format:
`postgres://user:pass@host:port/database?option=value`.
"""
pool = await asyncpg.create_pool(dsn)

async def init_pool(conn: asyncpg.Connection):
# Need the extension before we register, so do this here.
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")

from pgvector.asyncpg import register_vector

await register_vector(conn)

pool = await asyncpg.create_pool(dsn, init=init_pool)
yield pool
pool.close()


def _pg_pool(request: Request) -> asyncpg.Pool:
return request.state.pg_pool


PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)]


async def _pg_connection(pool: PgPoolDep) -> asyncpg.Connection:
async with pool.acquire() as connection:
yield connection

PgConnectionDep = Annotated[asyncpg.Connection, Depends(_pg_connection)]

PgConnectionDep = Annotated[asyncpg.Connection, Depends(_pg_connection)]
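For context, a hypothetical sketch of how the pool set up by `create_pool` might be queried (assumptions: the `embedding` table and partial HNSW indexes from the migration exist, and pgvector's `<=>` cosine-distance operator matches the `vector_cosine_ops` indexes; none of this code is in the diff):

```python
import numpy as np

from app.common.db import create_pool


async def nearest_chunks(dsn: str, query_embedding: np.ndarray) -> list:
    async with create_pool(dsn) as pool:
        async with pool.acquire() as conn:
            # register_vector ran in init_pool, so the numpy array is encoded
            # as a pgvector value without manual casting.
            return await conn.fetch(
                """
                SELECT id, chunk_id
                FROM embedding
                WHERE embedding_model = 'hf_baai_bge_small_en'
                ORDER BY embedding::vector(384) <=> $1
                LIMIT 5
                """,
                query_embedding,
            )
```

The cast to `vector(384)` and the `embedding_model` predicate mirror the partial index definition, which is what would let the planner use that index.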
2 changes: 1 addition & 1 deletion app/config.py
@@ -1,7 +1,7 @@
from typing import Any, Optional

from fastapi.routing import APIRoute
from pydantic import RedisDsn, PostgresDsn, ValidationInfo, field_validator
from pydantic import PostgresDsn, RedisDsn, ValidationInfo, field_validator
from pydantic_core import Url
from pydantic_settings import BaseSettings

3 changes: 2 additions & 1 deletion app/documents/models.py
@@ -14,6 +14,7 @@ class IngestState(Enum):
FAILED = "failed"
"""Document failed to be ingested. See `ingest_errors` for details."""


class Document(BaseModel):
"""Schema for documents in the SQL DB."""

@@ -23,4 +24,4 @@ class Document(BaseModel):
url: str

ingest_state: Optional[IngestState] = None
ingest_error: Optional[str] = None
ingest_error: Optional[str] = None
50 changes: 33 additions & 17 deletions app/documents/router.py
@@ -1,14 +1,12 @@
from typing import Annotated, List
import asyncpg

import asyncpg
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Path, status
from loguru import logger
from sqlalchemy import Engine
from sqlmodel import Session, select
from app.collections.router import PathCollectionId

from app.collections.router import PathCollectionId
from app.common.db import PgConnectionDep, PgPoolDep
from app.documents.models import *
from app.documents.models import Document
from app.ingest.extract import extract
from app.ingest.extract.source import ExtractSource
from app.ingest.store import Store, StoreDep
@@ -43,11 +41,15 @@ async def ingest_document(id: int, store: Store, pg_pool: asyncpg.Pool):
nodes = await store.ingestion_pipeline.arun(documents=documents)
logger.debug("Done. Inserted {} nodes", len(nodes))

await conn.execute("""
await conn.execute(
"""
UPDATE document
SET ingest_state = 'ingested', ingest_error = NULL
WHERE id = $1
""", id)
""",
id,
)


@router.put("/")
async def add_document(
@@ -61,37 +63,51 @@ async def add_document(

row = None
async with pg_pool.acquire() as conn:
row = await conn.fetchrow("""
INSERT INTO document (collection_id, url, ingest_state) VALUES ($1, $2, 'pending')
row = await conn.fetchrow(
"""
INSERT INTO document (collection_id, url, ingest_state)
VALUES ($1, $2, 'pending')
RETURNING id, collection_id, url, ingest_state, ingest_error
""", collection_id, url)
""",
collection_id,
url,
)

document = Document.model_validate(dict(row))
background.add_task(ingest_document, document.id, store, pg_pool)
return document


PathDocumentId = Annotated[int, Path(..., description="The document ID.")]


@router.get("/")
async def list_documents(collection_id: PathCollectionId, conn: PgConnectionDep) -> List[Document]:
async def list_documents(
collection_id: PathCollectionId, conn: PgConnectionDep
) -> List[Document]:
"""List documents."""
# TODO: Test
results = await conn.fetch("""
results = await conn.fetch(
"""
SELECT id, collection_id, url, ingest_state, ingest_error
FROM document WHERE collection_id = $1
""", collection_id)
""",
collection_id,
)
return [Document.model_validate(dict(result)) for result in results]


@router.get("/{id}")
async def get_document(
conn: PgConnectionDep,
collection_id: PathCollectionId,
id: PathDocumentId
conn: PgConnectionDep, collection_id: PathCollectionId, id: PathDocumentId
) -> Document:
# TODO: Test / return not found?
result = await conn.fetchrow(
"""
SELECT id, collection_id, url, ingest_state, ingest_error
FROM document WHERE id = $1 AND collection_id = $2
""", id, collection_id)
""",
id,
collection_id,
)
return Document.model_validate(dict(result))
41 changes: 23 additions & 18 deletions app/main.py
@@ -1,11 +1,10 @@
import contextlib
from typing import AsyncIterator, TypedDict

import asyncpg
from fastapi import FastAPI
from sqlalchemy import Engine
from sqlmodel import SQLModel, create_engine
from loguru import logger

from app.collections.models import EmbeddingModel
from app.common import db
from app.config import app_configs, settings
from app.ingest.store import Store
@@ -14,33 +13,39 @@

class State(TypedDict):
store: Store
db: Engine
pg_pool: asyncpg.Pool


@contextlib.asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
"""Function creating instances used during the lifespan of the service."""

# if settings.APPLY_MIGRATIONS:
# from yoyo import get_backend, read_migrations
# backend = get_backend(settings.DB.unicode_string())
# migrations = read_migrations('migrations')
# with backend.lock():
# outstanding = backend.to_apply(migrations)

# logger.info("Applying {} migrations", len(outstanding))

# # Apply any outstanding migrations
# backend.apply_migrations(outstanding)

# logger.info("Done applying migrations.")

# TODO: Look at https://gist.github.com/mattbillenstein/270a4d44cbdcb181ac2ed58526ae137d
# for simple migration scripts.
async with db.create_pool(settings.DB.unicode_string()) as pg_pool:
if settings.APPLY_MIGRATIONS:
async with pg_pool.acquire() as conn:
with open("migrations/0001_schema.sql") as schema_file:
schema = schema_file.read()
await conn.execute(schema)

# TODO: create indices with different distance metrics and
# either allow configuring that, or setting a default for each
# embedding model? We'd need to change the `vector_cosine_ops`
# to `vector_{l2,ip}_ops` for l2 distance or inner-product, and
# use different operators when querying.
index_creation = [
f"""
CREATE INDEX IF NOT EXISTS {emb.name}_index
ON embedding
USING hnsw ((embedding::vector({emb.dimensions})) vector_cosine_ops)
WHERE (embedding_model = '{emb.name}');
"""
for emb in EmbeddingModel
]
print(index_creation)
await conn.execute("\n\n".join(index_creation))

state = {
"store": Store(),
"pg_pool": pg_pool,
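For reference, the f-string loop above expands to one partial HNSW index per `EmbeddingModel` member; a sketch of the generated statement for `hf_baai_bge_small_en` (384 dimensions) looks like the following, and swapping `vector_cosine_ops` for `vector_l2_ops` or `vector_ip_ops` is what the TODO about other distance metrics refers to:

```sql
CREATE INDEX IF NOT EXISTS hf_baai_bge_small_en_index
ON embedding
USING hnsw ((embedding::vector(384)) vector_cosine_ops)
WHERE (embedding_model = 'hf_baai_bge_small_en');
```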
2 changes: 1 addition & 1 deletion app/routes.py
@@ -1,8 +1,8 @@
from fastapi import APIRouter

from app.chunks.router import router as chunks_router
from app.collections.router import router as collections_router
from app.documents.router import router as documents_router
from app.chunks.router import router as chunks_router

api_router = APIRouter(prefix="/api")

9 changes: 8 additions & 1 deletion migrations/0001_schema.sql
@@ -1,8 +1,13 @@
-- Apply the base schema.
CREATE TYPE embedding_model as ENUM (
'openai_text_embedding_ada_002',
'hf_baai_bge_small_en'
);

CREATE TABLE collection (
id SERIAL NOT NULL,
name VARCHAR NOT NULL,
text_embedding_model embedding_model NOT NULL,

PRIMARY KEY (id)
);
@@ -82,7 +87,9 @@ CREATE TYPE embedding_kind AS ENUM (
CREATE TABLE embedding(
id SERIAL NOT NULL,

chunk_id INTEGER,
embedding vector NOT NULL,
embedding_model embedding_model NOT NULL,
chunk_id INTEGER NOT NULL,

key_text VARCHAR,
