Skip to content

Commit

Permalink
feat: Add a script to extract openapi.yaml (#44)
Browse files Browse the repository at this point in the history
* feat: Add a script to extract openapi.yaml

This adds a script that we can run with:

```sh
poetry run python scripts/extract_openapi.py dewy.main:app
```

This also modifies the options to eliminate the LLM/embedding
environment variables (not used) and make the DB optional (failing
the CRUD methods if invoked).

This is part of #43.

* add poe tasks

* tasks; lint
  • Loading branch information
bjchambers authored Jan 30, 2024
1 parent c4af67b commit 1e3d9ba
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 129 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
__pycache__
.env
.vscode/
openapi.yaml
5 changes: 4 additions & 1 deletion dewy/common/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ async def init_pool(conn: asyncpg.Connection):


def _pg_pool(request: Request) -> asyncpg.Pool:
return request.state.pg_pool
pg_pool = request.state.pg_pool
if pg_pool is None:
raise ValueError("DB not configured. Unable to get pool.")
return pg_pool


PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)]
Expand Down
71 changes: 6 additions & 65 deletions dewy/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Optional

from fastapi.routing import APIRoute
from pydantic import PostgresDsn, RedisDsn, ValidationInfo, field_validator
from pydantic_core import Url
from pydantic import PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict

from dewy.constants import Environment
Expand All @@ -21,8 +20,11 @@ class Config(BaseSettings):
SERVE_ADMIN_UI: bool = True
"""If true, serve the admin UI."""

DB: PostgresDsn
"""The Postgres database to connect to."""
DB: Optional[PostgresDsn] = None
"""The Postgres database to connect to.
If not provided, none of the CRUD methods will work.
"""

APPLY_MIGRATIONS: bool = True
"""Whether migrations should be applied to the database.
Expand All @@ -34,73 +36,12 @@ class Config(BaseSettings):
ENVIRONMENT: Environment = Environment.PRODUCTION
"""The environment the application is running in."""

REDIS: Optional[RedisDsn] = None
"""The Redis service to use for queueing, indexing and document storage."""

EMBEDDING_MODEL: str = ""
"""The embedding model to use.
This is a string of the form `<kind>[:<model>]` with the following options:
1. `local[:<path or repository>]` -- Run a model locally. If this is a path
it will attempt to load a model from that location. Otherwise, it should
be a Hugging Face repository from which to retrieve the model.
2. `openai[:<name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
3. `ollama:<name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.
In each of these cases, you can omit the second part for the default model of the
given kind.
If unset, this will default to `"openai"` if an OpenAI API KEY is available and
otherwise will use `"local"`.
NOTE: Changing embedding models is not currently supported.
"""

LLM_MODEL: str = ""
"""The LLM model to use.
This is a string of the form `<kind>:<model>` with the following options:
1. `local[:<path or repository>]` -- Run a model locally. If this is a path
it will attempt to load a model from that location. Otherwise, it should
be a Hugging Face repository from which to retrieve the model.
2. `openai[:<name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
3. `ollama:<name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.
In each of these cases, you can omit the second part for the default model of the
given kind.
If unset, this will default to `"openai"` if an OpenAI API KEY is available and
otherwise will use `"local"`.
"""

OPENAI_API_KEY: Optional[str] = None
""" The OpenAI API Key to use for OpenAI models.
This is required for using openai models.
"""

OLLAMA_BASE_URL: Optional[Url] = None
"""The Base URL for Ollama.
This is required for using ollama models.
"""

@field_validator("OLLAMA_BASE_URL")
def validate_ollama_base_url(cls, v, info: ValidationInfo):
MODELS = ["LLM_MODEL", "EMBEDDING_MODEL"]
if v is None:
for model in MODELS:
context = info.context
if context:
value = context.get(model, "")
if value.startswith("ollama"):
raise ValueError(
f"{info.field_name} must be set to use '{model}={value}'"
)
return v


settings = Config()

Expand Down
27 changes: 14 additions & 13 deletions dewy/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import os
from pathlib import Path
from typing import AsyncIterator, TypedDict
from typing import AsyncIterator, Optional, TypedDict

import asyncpg
import uvicorn
Expand All @@ -17,7 +17,7 @@


class State(TypedDict):
pg_pool: asyncpg.Pool
pg_pool: Optional[asyncpg.Pool]


# Resolve paths, independent of PWD
Expand All @@ -30,17 +30,18 @@ class State(TypedDict):
async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
"""Function creating instances used during the lifespan of the service."""

# TODO: Look at https://gist.github.com/mattbillenstein/270a4d44cbdcb181ac2ed58526ae137d
# for simple migration scripts.
async with db.create_pool(settings.DB.unicode_string()) as pg_pool:
if settings.APPLY_MIGRATIONS:
async with pg_pool.acquire() as conn:
await apply_migrations(conn, migration_dir=migrations_path)

logger.info("Created database connection")
state = {
"pg_pool": pg_pool,
}
if settings.DB is not None:
async with db.create_pool(settings.DB.unicode_string()) as pg_pool:
if settings.APPLY_MIGRATIONS:
async with pg_pool.acquire() as conn:
await apply_migrations(conn, migration_dir=migrations_path)

logger.info("Created database connection")
state = State(pg_pool=pg_pool)
yield state
else:
logger.warn("No database configured. CRUD methods will fail.")
state = State(pg_pool=None)
yield state


Expand Down
58 changes: 9 additions & 49 deletions example_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -15,17 +15,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 1, 'name': 'my_collection', 'text_embedding_model': 'openai:text-embedding-ada-002', 'text_distance_metric': 'cosine'}\n"
]
}
],
"outputs": [],
"source": [
"response = client.put(f\"/collections/\",\n",
" json = {\n",
Expand All @@ -39,17 +31,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 1, 'collection_id': 1, 'url': 'https://arxiv.org/pdf/2305.14283.pdf', 'ingest_state': 'pending', 'ingest_error': None}\n"
]
}
],
"outputs": [],
"source": [
"# Add \"Query Rewriting for Retrieval-Augmented Large Language Models\"\n",
"response = client.put(f\"/documents/\",\n",
Expand All @@ -64,17 +48,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 1, 'collection_id': 1, 'url': 'https://arxiv.org/pdf/2305.14283.pdf', 'ingest_state': 'ingested', 'ingest_error': None}\n"
]
}
],
"outputs": [],
"source": [
"# Report the status of the document ingestion.\n",
"response = client.get(f\"/documents/{document_id}\")\n",
Expand All @@ -83,25 +59,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "JSONDecodeError",
"evalue": "Expecting value: line 1 column 1 (char 0)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 9\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Retrieve 4 items with no summary.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m results \u001b[38;5;241m=\u001b[39m client\u001b[38;5;241m.\u001b[39mpost(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/chunks/retrieve\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3\u001b[0m json \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcollection_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: collection_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 7\u001b[0m },\n\u001b[1;32m 8\u001b[0m timeout \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mresults\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjson\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 10\u001b[0m results\u001b[38;5;241m.\u001b[39mraise_for_status()\n\u001b[1;32m 12\u001b[0m results\u001b[38;5;241m.\u001b[39mjson()\n",
"File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/knowledge-7QbvxqGg-py3.11/lib/python3.11/site-packages/httpx/_models.py:762\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 761\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mjson\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: typing\u001b[38;5;241m.\u001b[39mAny) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m typing\u001b[38;5;241m.\u001b[39mAny:\n\u001b[0;32m--> 762\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mjsonlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 341\u001b[0m s \u001b[38;5;241m=\u001b[39m s\u001b[38;5;241m.\u001b[39mdecode(detect_encoding(s), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msurrogatepass\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m JSONDecoder\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, s, _w\u001b[38;5;241m=\u001b[39mWHITESPACE\u001b[38;5;241m.\u001b[39mmatch):\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03m containing a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m end \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(s):\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/decoder.py:355\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscan_once(s, idx)\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m JSONDecodeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpecting value\u001b[39m\u001b[38;5;124m\"\u001b[39m, s, err\u001b[38;5;241m.\u001b[39mvalue) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 356\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj, end\n",
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)"
]
}
],
"outputs": [],
"source": [
"# Retrieve 4 items with no summary.\n",
"results = client.post(f\"/chunks/retrieve\",\n",
Expand Down
42 changes: 41 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,34 @@ include = [
[tool.poetry.scripts]
dewy = "dewy.main:run"

[tool.poe.tasks.check-ruff]
help = "Check ruff formatting and checks."
sequence = ["check-ruff-format", "check-ruff-checks"]

[tool.poe.tasks.fix-ruff]
help = "Fix ruff formatting and checks."
sequence = ["fix-ruff-format", "fix-ruff-checks"]

[tool.poe.tasks.check-ruff-checks]
cmd = "ruff check"

[tool.poe.tasks.check-ruff-format]
cmd = "ruff format --check"

[tool.poe.tasks.fix-ruff-checks]
cmd = "ruff check --fix"

[tool.poe.tasks.fix-ruff-format]
cmd = "ruff format"

[tool.poe.tasks.test]
help = "Run unit and feature tests"
cmd = "pytest"

[tool.poe.tasks.extract-openapi]
help = "Update openapi.yaml from the swagger docs"
cmd = "python scripts/extract_openapi.py dewy.main:app"

[tool.poetry.dependencies]
python = "^3.11"
pydantic = "^2.5.3"
Expand Down
Loading

0 comments on commit 1e3d9ba

Please sign in to comment.