-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This uses asyncpg directly. This removes the tags from the methods, and changes the names to be more descriptive. The intent is to have a single `service` generated from the OpenAPI specification that includes methods like `retrieve_chunks` rather than `chunks.retrieve(...)`.
- Loading branch information
1 parent
1b3da43
commit 3277785
Showing
19 changed files
with
599 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,14 @@ | ||
# Ignore everything | ||
** | ||
|
||
# Include (don't ignore) the application code | ||
!app | ||
!./pyproject.toml | ||
!./poetry.lock | ||
**/__pycache__ | ||
|
||
# Re-ignore pycache within `app`. | ||
**/__pycache__ | ||
|
||
# Include (don't ignore) the migrations. | ||
!migrations/*.sql | ||
!yoyo.ini |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
from pydantic import BaseModel, ConfigDict, TypeAdapter | ||
|
||
class Collection(BaseModel): | ||
model_config=ConfigDict(from_attributes=True) | ||
|
||
"""A collection of indexed documents.""" | ||
id: int | ||
"""The ID of the collection.""" | ||
|
||
name: str | ||
"""The name of the collection.""" | ||
|
||
collection_validator = TypeAdapter(Collection) | ||
|
||
|
||
class CollectionCreate(BaseModel): | ||
"""The request to create a collection.""" | ||
|
||
name: str | ||
"""The name of the collection.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,37 @@ | ||
from typing import Annotated, List | ||
|
||
from fastapi import APIRouter, Path | ||
from sqlmodel import Session, select | ||
from pydantic import parse_obj_as | ||
|
||
from app.common.schema import Collection, EngineDep | ||
from app.common.db import PgConnectionDep | ||
from app.collections.models import * | ||
|
||
router = APIRouter(tags=["collections"], prefix="/collections") | ||
router = APIRouter(prefix="/collections") | ||
|
||
|
||
@router.put("/") | ||
async def add(engine: EngineDep, collection: Collection) -> Collection: | ||
async def add_collection(conn: PgConnectionDep, collection: CollectionCreate) -> Collection: | ||
"""Create a collection.""" | ||
with Session(engine) as session: | ||
session.add(collection) | ||
session.commit() | ||
session.refresh(collection) | ||
return collection | ||
result = await conn.fetchrow(""" | ||
INSERT INTO collection (name) VALUES ($1) | ||
RETURNING id, name | ||
""", | ||
collection.name) | ||
return Collection.model_validate(dict(result)) | ||
|
||
|
||
@router.get("/") | ||
async def list(engine: EngineDep) -> List[Collection]: | ||
async def list_collections(conn: PgConnectionDep) -> List[Collection]: | ||
"""List collections.""" | ||
with Session(engine) as session: | ||
return session.exec(select(Collection)).all() | ||
results = await conn.fetch("SELECT id, name FROM collection") | ||
return [Collection.model_validate(dict(result)) for result in results] | ||
|
||
|
||
PathCollectionId = Annotated[int, Path(..., description="The collection ID.")] | ||
|
||
|
||
@router.get("/{id}") | ||
async def get(id: PathCollectionId, engine: EngineDep) -> Collection: | ||
async def get_collection(id: PathCollectionId, conn: PgConnectionDep) -> Collection: | ||
"""Get a specific collection.""" | ||
with Session(engine) as session: | ||
return session.get(Collection, id) | ||
result = await conn.fetchrow("SELECT id, name FROM collection WHERE id = $1", id) | ||
return Collection.model_validate(dict(result)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import contextlib | ||
from enum import Enum | ||
from typing import Annotated, AsyncIterator, Optional | ||
from uuid import UUID | ||
import asyncpg | ||
|
||
from fastapi import Depends, Request | ||
|
||
@contextlib.asynccontextmanager | ||
async def create_pool(dsn: str) -> AsyncIterator[asyncpg.Pool]: | ||
""" | ||
Create a postgres connection pool. | ||
Arguments: | ||
- dsn: Connection arguments specified using as a single string in | ||
the following format: | ||
`postgres://user:pass@host:port/database?option=value`. | ||
""" | ||
pool = await asyncpg.create_pool(dsn) | ||
yield pool | ||
pool.close() | ||
|
||
def _pg_pool(request: Request) -> asyncpg.Pool: | ||
return request.state.pg_pool | ||
|
||
PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)] | ||
|
||
async def _pg_connection(pool: PgPoolDep) -> asyncpg.Connection: | ||
async with pool.acquire() as connection: | ||
yield connection | ||
|
||
PgConnectionDep = Annotated[asyncpg.Connection, Depends(_pg_connection)] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class IngestState(Enum): | ||
PENDING = "pending" | ||
"""Document is pending ingestion.""" | ||
|
||
INGESTED = "ingested" | ||
"""Document has been ingested.""" | ||
|
||
FAILED = "failed" | ||
"""Document failed to be ingested. See `ingest_errors` for details.""" | ||
|
||
class Document(BaseModel): | ||
"""Schema for documents in the SQL DB.""" | ||
|
||
id: Optional[int] = None | ||
collection_id: int | ||
|
||
url: str | ||
|
||
ingest_state: Optional[IngestState] = None | ||
ingest_error: Optional[str] = None |
Oops, something went wrong.