diff --git a/.gitignore b/.gitignore index c5a859d..78b6a8d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__ .env .vscode/ +openapi.yaml \ No newline at end of file diff --git a/dewy/common/db.py b/dewy/common/db.py index a2bf0df..7da5aae 100644 --- a/dewy/common/db.py +++ b/dewy/common/db.py @@ -30,7 +30,10 @@ async def init_pool(conn: asyncpg.Connection): def _pg_pool(request: Request) -> asyncpg.Pool: - return request.state.pg_pool + pg_pool = request.state.pg_pool + if pg_pool is None: + raise ValueError("DB not configured. Unable to get pool.") + return pg_pool PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)] diff --git a/dewy/config.py b/dewy/config.py index 9adfa4e..c56b022 100644 --- a/dewy/config.py +++ b/dewy/config.py @@ -1,8 +1,7 @@ from typing import Any, Optional from fastapi.routing import APIRoute -from pydantic import PostgresDsn, RedisDsn, ValidationInfo, field_validator -from pydantic_core import Url +from pydantic import PostgresDsn from pydantic_settings import BaseSettings, SettingsConfigDict from dewy.constants import Environment @@ -21,8 +20,11 @@ class Config(BaseSettings): SERVE_ADMIN_UI: bool = True """If true, serve the admin UI.""" - DB: PostgresDsn - """The Postgres database to connect to.""" + DB: Optional[PostgresDsn] = None + """The Postgres database to connect to. + + If not provided, none of the CRUD methods will work. + """ APPLY_MIGRATIONS: bool = True """Whether migrations should be applied to the database. @@ -34,73 +36,12 @@ class Config(BaseSettings): ENVIRONMENT: Environment = Environment.PRODUCTION """The environment the application is running in.""" - REDIS: Optional[RedisDsn] = None - """The Redis service to use for queueing, indexing and document storage.""" - - EMBEDDING_MODEL: str = "" - """The embedding model to use. - - This is a string of the form `[:]` with the following options: - - 1. `local[:]` -- Run a model locally. 
If this is a path - it will attempt to load a model from that location. Otherwise, it should - be a Hugging Face repository from which to retrieve the model. - 2. `openai[:]` -- The named OpenAI model. `OPENAI_API_KEY` must be set. - 3. `ollama:` -- The named Ollama model. `OLLAMA_BASE_URL` must be set. - - In each of these cases, you can omit the second part for the default model of the - given kind. - - If unset, this will default to `"openai"` if an OpenAI API KEY is available and - otherwise will use `"local"`. - - NOTE: Changing embedding models is not currently supported. - """ - - LLM_MODEL: str = "" - """The LLM model to use. - - This is a string of the form `:` with the following options: - - 1. `local[:]` -- Run a model locally. If this is a path - it will attempt to load a model from that location. Otherwise, it should - be a Hugging Face repository from which to retrieve the model. - 2. `openai[:]` -- The named OpenAI model. `OPENAI_API_KEY` must be set. - 3. `ollama:` -- The named Ollama model. `OLLAMA_BASE_URL` must be set. - - In each of these cases, you can omit the second part for the default model of the - given kind. - - If unset, this will default to `"openai"` if an OpenAI API KEY is available and - otherwise will use `"local"`. - """ - OPENAI_API_KEY: Optional[str] = None """ The OpenAI API Key to use for OpenAI models. This is required for using openai models. """ - OLLAMA_BASE_URL: Optional[Url] = None - """The Base URL for Ollama. - - This is required for using ollama models. 
- """ - - @field_validator("OLLAMA_BASE_URL") - def validate_ollama_base_url(cls, v, info: ValidationInfo): - MODELS = ["LLM_MODEL", "EMBEDDING_MODEL"] - if v is None: - for model in MODELS: - context = info.context - if context: - value = context.get(model, "") - if value.startswith("ollama"): - raise ValueError( - f"{info.field_name} must be set to use '{model}={value}'" - ) - return v - settings = Config() diff --git a/dewy/main.py b/dewy/main.py index e896edf..4eabb61 100644 --- a/dewy/main.py +++ b/dewy/main.py @@ -1,7 +1,7 @@ import contextlib import os from pathlib import Path -from typing import AsyncIterator, TypedDict +from typing import AsyncIterator, Optional, TypedDict import asyncpg import uvicorn @@ -17,7 +17,7 @@ class State(TypedDict): - pg_pool: asyncpg.Pool + pg_pool: Optional[asyncpg.Pool] # Resolve paths, independent of PWD @@ -30,17 +30,18 @@ class State(TypedDict): async def lifespan(_app: FastAPI) -> AsyncIterator[State]: """Function creating instances used during the lifespan of the service.""" - # TODO: Look at https://gist.github.com/mattbillenstein/270a4d44cbdcb181ac2ed58526ae137d - # for simple migration scripts. - async with db.create_pool(settings.DB.unicode_string()) as pg_pool: - if settings.APPLY_MIGRATIONS: - async with pg_pool.acquire() as conn: - await apply_migrations(conn, migration_dir=migrations_path) - - logger.info("Created database connection") - state = { - "pg_pool": pg_pool, - } + if settings.DB is not None: + async with db.create_pool(settings.DB.unicode_string()) as pg_pool: + if settings.APPLY_MIGRATIONS: + async with pg_pool.acquire() as conn: + await apply_migrations(conn, migration_dir=migrations_path) + + logger.info("Created database connection") + state = State(pg_pool=pg_pool) + yield state + else: + logger.warn("No database configured. 
CRUD methods will fail.") + state = State(pg_pool=None) yield state diff --git a/example_notebook.ipynb b/example_notebook.ipynb index fc3167c..aac1b97 100644 --- a/example_notebook.ipynb +++ b/example_notebook.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -15,17 +15,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'id': 1, 'name': 'my_collection', 'text_embedding_model': 'openai:text-embedding-ada-002', 'text_distance_metric': 'cosine'}\n" - ] - } - ], + "outputs": [], "source": [ "response = client.put(f\"/collections/\",\n", " json = {\n", @@ -39,17 +31,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'id': 1, 'collection_id': 1, 'url': 'https://arxiv.org/pdf/2305.14283.pdf', 'ingest_state': 'pending', 'ingest_error': None}\n" - ] - } - ], + "outputs": [], "source": [ "# Add \"Query Rewriting for Retrieval-Augmented Large Language Models\"\n", "response = client.put(f\"/documents/\",\n", @@ -64,17 +48,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'id': 1, 'collection_id': 1, 'url': 'https://arxiv.org/pdf/2305.14283.pdf', 'ingest_state': 'ingested', 'ingest_error': None}\n" - ] - } - ], + "outputs": [], "source": [ "# Report the status of the document ingestion.\n", "response = client.get(f\"/documents/{document_id}\")\n", @@ -83,25 +59,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "JSONDecodeError", - "evalue": "Expecting value: line 1 column 1 (char 0)", - "output_type": "error", - 
"traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 9\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Retrieve 4 items with no summary.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m results \u001b[38;5;241m=\u001b[39m client\u001b[38;5;241m.\u001b[39mpost(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/chunks/retrieve\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3\u001b[0m json \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcollection_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: collection_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 7\u001b[0m },\n\u001b[1;32m 8\u001b[0m timeout \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mresults\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjson\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 10\u001b[0m results\u001b[38;5;241m.\u001b[39mraise_for_status()\n\u001b[1;32m 12\u001b[0m results\u001b[38;5;241m.\u001b[39mjson()\n", - "File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/knowledge-7QbvxqGg-py3.11/lib/python3.11/site-packages/httpx/_models.py:762\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 761\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mjson\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: typing\u001b[38;5;241m.\u001b[39mAny) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m typing\u001b[38;5;241m.\u001b[39mAny:\n\u001b[0;32m--> 762\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mjsonlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 341\u001b[0m s \u001b[38;5;241m=\u001b[39m s\u001b[38;5;241m.\u001b[39mdecode(detect_encoding(s), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msurrogatepass\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m JSONDecoder\n", - "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, s, _w\u001b[38;5;241m=\u001b[39mWHITESPACE\u001b[38;5;241m.\u001b[39mmatch):\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03m containing a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m end \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(s):\n", - "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/json/decoder.py:355\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscan_once(s, 
idx)\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m JSONDecodeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpecting value\u001b[39m\u001b[38;5;124m\"\u001b[39m, s, err\u001b[38;5;241m.\u001b[39mvalue) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 356\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj, end\n", - "\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)" - ] - } - ], + "outputs": [], "source": [ "# Retrieve 4 items with no summary.\n", "results = client.post(f\"/chunks/retrieve\",\n", diff --git a/poetry.lock b/poetry.lock index f48da10..91574a8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1821,6 +1821,17 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] +[[package]] +name = "pastel" +version = "0.2.1" +description = "Bring colors to your terminal." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"}, + {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -1878,6 +1889,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "poethepoet" +version = "0.24.4" +description = "A task runner that works well with poetry." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "poethepoet-0.24.4-py3-none-any.whl", hash = "sha256:fb4ea35d7f40fe2081ea917d2e4102e2310fda2cde78974050ca83896e229075"}, + {file = "poethepoet-0.24.4.tar.gz", hash = "sha256:ff4220843a87c888cbcb5312c8905214701d0af60ac7271795baa8369b428fef"}, +] + +[package.dependencies] +pastel = ">=0.2.1,<0.3.0" +tomli = ">=1.2.2" + +[package.extras] +poetry-plugin = ["poetry (>=1.0,<2.0)"] + [[package]] name = "prompt-toolkit" version = "3.0.43" @@ -3169,6 +3198,17 @@ dev = ["tokenizers[testing]"] docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + [[package]] name = "torch" version = "2.1.2" @@ -3651,4 +3691,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "71fde2d93537bb9f9de3a7b492e02cfa48d381d19f73c49daeac56251f8e02df" +content-hash = "d093075ba2c0fada13b8045f39f45e504eabbc653b2f337873dfaac787c52e68" diff --git a/pyproject.toml b/pyproject.toml index 3017975..5bfa50b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,34 @@ include = [ [tool.poetry.scripts] dewy = "dewy.main:run" +[tool.poe.tasks.check-ruff] +help = "Check ruff formatting and checks." +sequence = ["check-ruff-format", "check-ruff-checks"] + +[tool.poe.tasks.fix-ruff] +help = "Fix ruff formatting and checks." 
+sequence = ["fix-ruff-format", "fix-ruff-checks"] + +[tool.poe.tasks.check-ruff-checks] +cmd = "ruff check" + +[tool.poe.tasks.check-ruff-format] +cmd = "ruff format --check" + +[tool.poe.tasks.fix-ruff-checks] +cmd = "ruff check --fix" + +[tool.poe.tasks.fix-ruff-format] +cmd = "ruff format" + +[tool.poe.tasks.test] +help = "Run unit and feature tests" +cmd = "pytest" + +[tool.poe.tasks.extract-openapi] +help = "Update openapi.yaml from the swagger docs" +cmd = "python scripts/extract_openapi.py dewy.main:app" + [tool.poetry.dependencies] python = "^3.11" pydantic = "^2.5.3" diff --git a/scripts/extract_openapi.py b/scripts/extract_openapi.py new file mode 100644 index 0000000..a818cd0 --- /dev/null +++ b/scripts/extract_openapi.py @@ -0,0 +1,37 @@ +# Export script based on https://www.doctave.com/blog/python-export-fastapi-openapi-spec#step-2-create-the-export-script +import argparse +import json +import sys + +import yaml +from uvicorn.importer import import_from_string + +parser = argparse.ArgumentParser(prog="extract-openapi.py") +parser.add_argument( + "app", help='App import string. Eg. "main:app"', default="dewy.main:app" +) +parser.add_argument("--app-dir", help="Directory containing the app", default=None) +parser.add_argument( + "--out", help="Output file ending in .json or .yaml", default="openapi.yaml" +) + +if __name__ == "__main__": + args = parser.parse_args() + + if args.app_dir is not None: + print(f"adding {args.app_dir} to sys.path") + sys.path.insert(0, args.app_dir) + + print(f"importing app from {args.app}") + app = import_from_string(args.app) + openapi = app.openapi() + version = openapi.get("openapi", "unknown version") + + print(f"writing openapi spec v{version}") + with open(args.out, "w") as f: + if args.out.endswith(".json"): + json.dump(openapi, f, indent=2) + else: + yaml.dump(openapi, f, sort_keys=False) + + print(f"spec written to {args.out}")