Skip to content

Commit

Permalink
feat: add Model API authentication (#106)
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel.brandao <34553282+HolyMichael@users.noreply.github.com>
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Co-authored-by: Miguel.brandao <34553282+HolyMichael@users.noreply.github.com>
Co-authored-by: miguel <miguel.brandao@ibm.com>
  • Loading branch information
3 people authored Jun 22, 2023
1 parent bd9c4f4 commit 755f1a0
Show file tree
Hide file tree
Showing 8 changed files with 888 additions and 822 deletions.
33 changes: 24 additions & 9 deletions deepsearch/model/README.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
# Model API

> Currently in **beta**.
The Model API allows users to serve and integrate their own models.
The Model API is a unified and extensible inference API across different model kinds.

Built-in model kind support includes NLP annotators and QA generators.

## Installation
To use the Model API, install including the `api`
extra, i.e.:
To use the Model API, install including the `api` extra, i.e.:
- with poetry:
`poetry add "deepsearch-toolkit[api]"`
- with pip: `pip install "deepsearch-toolkit[api]"`

## Usage
To run a model, register it with a [`ModelApp`](server/model_app.py) and run the app:
## Basic usage
```python
from deepsearch.model.server.deepsearch_annotator_app import ModelApp
from deepsearch.model.server.config import Settings
from deepsearch.model.server.model_app import ModelApp

# (1) create an app
app = ModelApp(settings=Settings())

# (2) register your model(s)
model = ... # e.g. SimpleGeoNLPAnnotator()
app = ModelApp()
app.register_model(model)

# (3) run the app
app.run(host="127.0.0.1", port=8000)
```

### OpenAPI
### Settings
App configuration is done in [`Settings`](server/config.py) based on
[Pydantic settings management](https://docs.pydantic.dev/latest/usage/settings/).

E.g. the required API key can be set via env var `DS_MODEL_API_KEY` (for precedence rules,
check the
[Pydantic docs](https://docs.pydantic.dev/latest/usage/settings/#field-value-priority)).

### OpenAPI
The OpenAPI UI is served under `/docs`, e.g. http://127.0.0.1:8000/docs.

## Developing a new model
Expand All @@ -40,6 +52,9 @@ optional parameter `controller`.
- [Simple geo NLP annotator](examples/simple_geo_nlp_annotator/)
- [Dummy QA generator](examples/dummy_qa_generator/)

Note: these examples configure the app with API key "example123"; when running them, use
the same key to access the protected endpoints.

### Inference
As as example, an input payload for the `/predict` endpoint for the geography annotator
could look as follows (note that `deepsearch.res.ibm.com/x-deadline` should be a
Expand Down
4 changes: 3 additions & 1 deletion deepsearch/model/examples/dummy_nlp_annotator/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from deepsearch.model.examples.dummy_nlp_annotator.model import DummyNLPAnnotator
from deepsearch.model.server.config import Settings
from deepsearch.model.server.model_app import ModelApp


def run():
app = ModelApp()
settings = Settings(api_key="example123")
app = ModelApp(settings)
app.register_model(DummyNLPAnnotator())
app.run()

Expand Down
4 changes: 3 additions & 1 deletion deepsearch/model/examples/dummy_qa_generator/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from deepsearch.model.examples.dummy_qa_generator.model import DummyQAGenerator
from deepsearch.model.server.config import Settings
from deepsearch.model.server.model_app import ModelApp


def run():
app = ModelApp()
settings = Settings(api_key="example123")
app = ModelApp(settings)
app.register_model(DummyQAGenerator())
app.run()

Expand Down
4 changes: 3 additions & 1 deletion deepsearch/model/examples/simple_geo_nlp_annotator/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from deepsearch.model.examples.simple_geo_nlp_annotator.model import ( # type: ignore
SimpleGeoNLPAnnotator,
)
from deepsearch.model.server.config import Settings
from deepsearch.model.server.model_app import ModelApp


def run():
app = ModelApp()
settings = Settings(api_key="example123")
app = ModelApp(settings)
app.register_model(SimpleGeoNLPAnnotator())
app.run()

Expand Down
8 changes: 8 additions & 0 deletions deepsearch/model/server/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseSettings, SecretStr


class Settings(BaseSettings):
api_key: SecretStr

class Config:
env_prefix = "DS_MODEL_"
30 changes: 22 additions & 8 deletions deepsearch/model/server/model_app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import asyncio
import logging
import os
Expand All @@ -9,22 +7,26 @@
import uvicorn
from anyio import CapacityLimiter
from anyio.lowlevel import RunVar
from fastapi import FastAPI, HTTPException, Request, status
from fastapi import Depends, FastAPI, HTTPException, Request, Security, status
from fastapi.concurrency import run_in_threadpool
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader

from deepsearch.model.base.controller import BaseController
from deepsearch.model.base.model import BaseDSModel
from deepsearch.model.server.config import Settings
from deepsearch.model.server.controller_factory import ControllerFactory
from deepsearch.model.server.inference_types import AppInferenceInput

logger = logging.getLogger("cps-fastapi")


class ModelApp:
def __init__(self):
def __init__(self, settings: Settings):
self._settings = settings

self.app = FastAPI()
self._controllers: Dict[str, BaseController] = {}
self._contr_factory = ControllerFactory()
Expand All @@ -48,21 +50,20 @@ async def health_check() -> dict:
return {"message": "HealthCheck"}

@self.app.get("/")
async def get_definitions() -> dict:
async def get_definitions(api_key=Depends(self._auth)) -> dict:
return {
name: controller.get_info()
for name, controller in self._controllers.items()
}

@self.app.get("/model/{model_name}")
async def get_model_specs(model_name: str) -> dict:
async def get_model_specs(model_name: str, api_key=Depends(self._auth)) -> dict:
controller = self._get_controller(model_name=model_name)
return controller.get_info()

@self.app.post("/model/{model_name}/predict", response_model=None)
async def predict(
model_name: str,
request: AppInferenceInput,
model_name: str, request: AppInferenceInput, api_key=Depends(self._auth)
) -> JSONResponse:
request_arrival_time = time.time()
try:
Expand Down Expand Up @@ -145,6 +146,19 @@ def _inference_process(
headers["X-request-id"] = str(request_dict["id"])
return JSONResponse(content=jsonable_encoder(result), headers=headers)

def _auth(self, header_api_key: str = Security(APIKeyHeader(name="Authorization"))):
request_api_key = (
header_api_key.replace("Bearer ", "")
.replace("bearer ", "")
.replace("Bearer: ", "")
.replace("bearer: ", "")
.strip()
)
if request_api_key != self._settings.api_key.get_secret_value():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)

def _get_controller(self, model_name: str) -> BaseController:
controller = self._controllers.get(model_name)
if controller is None:
Expand Down
Loading

0 comments on commit 755f1a0

Please sign in to comment.