Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a script to extract openapi.yaml #44

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
__pycache__
.env
.vscode/
openapi.yaml
5 changes: 4 additions & 1 deletion dewy/common/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ async def init_pool(conn: asyncpg.Connection):


def _pg_pool(request: Request) -> asyncpg.Pool:
return request.state.pg_pool
pg_pool = request.state.pg_pool
if pg_pool is None:
raise ValueError("DB not configured. Unable to get pool.")
return pg_pool


PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)]
Expand Down
71 changes: 6 additions & 65 deletions dewy/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Optional

from fastapi.routing import APIRoute
from pydantic import PostgresDsn, RedisDsn, ValidationInfo, field_validator
from pydantic_core import Url
from pydantic import PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict

from dewy.constants import Environment
Expand All @@ -21,8 +20,11 @@ class Config(BaseSettings):
SERVE_ADMIN_UI: bool = True
"""If true, serve the admin UI."""

DB: PostgresDsn
"""The Postgres database to connect to."""
DB: Optional[PostgresDsn] = None
"""The Postgres database to connect to.

If not provided, none of the CRUD methods will work.
"""

APPLY_MIGRATIONS: bool = True
"""Whether migrations should be applied to the database.
Expand All @@ -34,73 +36,12 @@ class Config(BaseSettings):
ENVIRONMENT: Environment = Environment.PRODUCTION
"""The environment the application is running in."""

REDIS: Optional[RedisDsn] = None
"""The Redis service to use for queueing, indexing and document storage."""

EMBEDDING_MODEL: str = ""
"""The embedding model to use.

This is a string of the form `<kind>[:<model>]` with the following options:

1. `local[:<path or repository>]` -- Run a model locally. If this is a path
it will attempt to load a model from that location. Otherwise, it should
be a Hugging Face repository from which to retrieve the model.
2. `openai[:<name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
3. `ollama:<name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.

In each of these cases, you can omit the second part for the default model of the
given kind.

If unset, this will default to `"openai"` if an OpenAI API KEY is available and
otherwise will use `"local"`.

NOTE: Changing embedding models is not currently supported.
"""

LLM_MODEL: str = ""
"""The LLM model to use.

This is a string of the form `<kind>:<model>` with the following options:

1. `local[:<path or repository>]` -- Run a model locally. If this is a path
it will attempt to load a model from that location. Otherwise, it should
be a Hugging Face repository from which to retrieve the model.
2. `openai[:<name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
3. `ollama:<name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.

In each of these cases, you can omit the second part for the default model of the
given kind.

If unset, this will default to `"openai"` if an OpenAI API KEY is available and
otherwise will use `"local"`.
"""

OPENAI_API_KEY: Optional[str] = None
""" The OpenAI API Key to use for OpenAI models.

This is required for using openai models.
"""

OLLAMA_BASE_URL: Optional[Url] = None
"""The Base URL for Ollama.

This is required for using ollama models.
"""

@field_validator("OLLAMA_BASE_URL")
def validate_ollama_base_url(cls, v, info: ValidationInfo):
MODELS = ["LLM_MODEL", "EMBEDDING_MODEL"]
if v is None:
for model in MODELS:
context = info.context
if context:
value = context.get(model, "")
if value.startswith("ollama"):
raise ValueError(
f"{info.field_name} must be set to use '{model}={value}'"
)
return v


settings = Config()

Expand Down
27 changes: 14 additions & 13 deletions dewy/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import os
from pathlib import Path
from typing import AsyncIterator, TypedDict
from typing import AsyncIterator, Optional, TypedDict

import asyncpg
import uvicorn
Expand All @@ -17,7 +17,7 @@


class State(TypedDict):
pg_pool: asyncpg.Pool
pg_pool: Optional[asyncpg.Pool]


# Resolve paths, independent of PWD
Expand All @@ -30,17 +30,18 @@ class State(TypedDict):
async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
"""Function creating instances used during the lifespan of the service."""

# TODO: Look at https://gist.github.com/mattbillenstein/270a4d44cbdcb181ac2ed58526ae137d
# for simple migration scripts.
async with db.create_pool(settings.DB.unicode_string()) as pg_pool:
if settings.APPLY_MIGRATIONS:
async with pg_pool.acquire() as conn:
await apply_migrations(conn, migration_dir=migrations_path)

logger.info("Created database connection")
state = {
"pg_pool": pg_pool,
}
if settings.DB is not None:
async with db.create_pool(settings.DB.unicode_string()) as pg_pool:
if settings.APPLY_MIGRATIONS:
async with pg_pool.acquire() as conn:
await apply_migrations(conn, migration_dir=migrations_path)

logger.info("Created database connection")
state = State(pg_pool=pg_pool)
yield state
else:
logger.warn("No database configured. CRUD methods will fail.")
state = State(pg_pool=None)
yield state


Expand Down
58 changes: 9 additions & 49 deletions example_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -15,17 +15,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 1, 'name': 'my_collection', 'text_embedding_model': 'openai:text-embedding-ada-002', 'text_distance_metric': 'cosine'}\n"
]
}
],
"outputs": [],
"source": [
"response = client.put(f\"/collections/\",\n",
" json = {\n",
Expand All @@ -39,17 +31,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 1, 'collection_id': 1, 'url': 'https://arxiv.org/pdf/2305.14283.pdf', 'ingest_state': 'pending', 'ingest_error': None}\n"
]
}
],
"outputs": [],
"source": [
"# Add \"Query Rewriting for Retrieval-Augmented Large Language Models\"\n",
"response = client.put(f\"/documents/\",\n",
Expand All @@ -64,17 +48,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 1, 'collection_id': 1, 'url': 'https://arxiv.org/pdf/2305.14283.pdf', 'ingest_state': 'ingested', 'ingest_error': None}\n"
]
}
],
"outputs": [],
"source": [
"# Report the status of the document ingestion.\n",
"response = client.get(f\"/documents/{document_id}\")\n",
Expand All @@ -83,25 +59,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "JSONDecodeError",
"evalue": "Expecting value: line 1 column 1 (char 0)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 9\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Retrieve 4 items with no summary.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m results \u001b[38;5;241m=\u001b[39m client\u001b[38;5;241m.\u001b[39mpost(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/chunks/retrieve\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3\u001b[0m json \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcollection_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: collection_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 7\u001b[0m },\n\u001b[1;32m 8\u001b[0m timeout \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mresults\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjson\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 10\u001b[0m results\u001b[38;5;241m.\u001b[39mraise_for_status()\n\u001b[1;32m 12\u001b[0m results\u001b[38;5;241m.\u001b[39mjson()\n",
"File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/knowledge-7QbvxqGg-py3.11/lib/python3.11/site-packages/httpx/_models.py:762\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 761\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mjson\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: typing\u001b[38;5;241m.\u001b[39mAny) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m typing\u001b[38;5;241m.\u001b[39mAny:\n\u001b[0;32m--> 762\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mjsonlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 341\u001b[0m s \u001b[38;5;241m=\u001b[39m s\u001b[38;5;241m.\u001b[39mdecode(detect_encoding(s), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msurrogatepass\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m JSONDecoder\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, s, _w\u001b[38;5;241m=\u001b[39mWHITESPACE\u001b[38;5;241m.\u001b[39mmatch):\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03m containing a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m end \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(s):\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/decoder.py:355\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscan_once(s, idx)\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m JSONDecodeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpecting value\u001b[39m\u001b[38;5;124m\"\u001b[39m, s, err\u001b[38;5;241m.\u001b[39mvalue) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 356\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj, end\n",
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)"
]
}
],
"outputs": [],
"source": [
"# Retrieve 4 items with no summary.\n",
"results = client.post(f\"/chunks/retrieve\",\n",
Expand Down
42 changes: 41 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,34 @@ include = [
[tool.poetry.scripts]
dewy = "dewy.main:run"

[tool.poe.tasks.check-ruff]
help = "Check ruff formatting and checks."
sequence = ["check-ruff-format", "check-ruff-checks"]

[tool.poe.tasks.fix-ruff]
help = "Fix ruff formatting and checks."
sequence = ["fix-ruff-format", "fix-ruff-checks"]

[tool.poe.tasks.check-ruff-checks]
cmd = "ruff check"

[tool.poe.tasks.check-ruff-format]
cmd = "ruff format --check"

[tool.poe.tasks.fix-ruff-checks]
cmd = "ruff check --fix"

[tool.poe.tasks.fix-ruff-format]
cmd = "ruff format"

[tool.poe.tasks.test]
help = "Run unit and feature tests"
cmd = "pytest"

[tool.poe.tasks.extract-openapi]
help = "Update openapi.toml from the swagger docs"
cmd = "python scripts/extract_openapi.py dewy.main:app"

[tool.poetry.dependencies]
python = "^3.11"
pydantic = "^2.5.3"
Expand Down
Loading
Loading