Skip to content

Commit

Permalink
* API support of timeout
Browse files Browse the repository at this point in the history
* clarified order of execution
  • Loading branch information
asofter committed Sep 23, 2023
1 parent 433ff99 commit 2a29a86
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 31 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Swagger documentation on the [API](https://llm-guard.com/usage/api/) documentation page
- Added `fail_fast` flag to stop the execution after the first failure
- Updated API and Playground to support `fail_fast` flag
- Clarified order of execution in the documentation
- Added timeout configuration for API example

### Removed
-
Expand Down
2 changes: 1 addition & 1 deletion examples/api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ install:
python -m spacy download en_core_web_trf

run:
@uvicorn app:app --reload
DEBUG=true uvicorn app:app --reload

build-docker:
@docker build -t $(DOCKER_IMAGE_NAME):$(VERSION) .
Expand Down
2 changes: 2 additions & 0 deletions examples/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ make run-docker
- `CACHE_MAX_SIZE` (int): Maximum number of items in the cache. Default is unlimited.
- `CACHE_TTL` (int): Time in seconds after which a cached item expires. Default is 1 hour.
- `SCAN_FAIL_FAST` (bool): Stop scanning after the first failed check. Default is `False`.
- `SCAN_PROMPT_TIMEOUT` (int): Time in seconds after which a prompt scan times out. Default is 10 seconds.
- `SCAN_OUTPUT_TIMEOUT` (int): Time in seconds after which an output scan times out. Default is 30 seconds.

### Scanners

Expand Down
104 changes: 74 additions & 30 deletions examples/api/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import asyncio
import concurrent.futures
import logging
import time
from datetime import timedelta

import schemas
from cache import InMemoryCache
from config import get_env_config, load_scanners_from_config
from fastapi import FastAPI, status
from fastapi import FastAPI, HTTPException, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
Expand Down Expand Up @@ -77,24 +81,47 @@ async def liveliness():
return JSONResponse({"status": "ready"})

@app.post("/analyze/output", tags=["Analyze"])
async def analyze_output(
    request: schemas.AnalyzeOutputRequest,
) -> schemas.AnalyzeOutputResponse:
    """Scan a model output with the configured output scanners.

    The scan is executed in a worker thread so the event loop stays
    responsive, and it is bounded by ``config["scan_output_timeout"]``.

    Args:
        request: Payload carrying the original prompt and the model output.

    Returns:
        The sanitized output, the overall validity flag and per-scanner scores.

    Raises:
        HTTPException: With status 408 when the scan does not finish within
            the configured timeout.
    """
    logger.debug(f"Received analyze request: {request}")

    # The setting may come straight from the environment as a string;
    # asyncio.wait_for needs a number.
    timeout = float(config["scan_output_timeout"])
    loop = asyncio.get_running_loop()
    start_time = time.monotonic()

    try:
        # Use the loop's default executor (None) instead of a per-request
        # ``with ThreadPoolExecutor()``: leaving that ``with`` block calls
        # shutdown(wait=True), which would block the event loop until the
        # timed-out scan finished and delay the 408 response.
        sanitized_output, results_valid, results_score = await asyncio.wait_for(
            loop.run_in_executor(
                None,
                scan_output,
                output_scanners,
                request.prompt,
                request.output,
                config["scan_fail_fast"],
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError as exc:
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout."
        ) from exc

    response = schemas.AnalyzeOutputResponse(
        sanitized_output=sanitized_output,
        is_valid=all(results_valid.values()),
        scanners=results_score,
    )
    elapsed_time = timedelta(seconds=time.monotonic() - start_time)
    logger.debug(
        f"Sanitized response with the score: {results_score}. Elapsed time: {elapsed_time}"
    )

    return response

@app.post("/analyze/prompt", tags=["Analyze"])
def analyze_prompt(request: schemas.AnalyzePromptRequest) -> schemas.AnalyzePromptResponse:
async def analyze_prompt(
request: schemas.AnalyzePromptRequest,
) -> schemas.AnalyzePromptResponse:
logger.debug(f"Received analyze request: {request}")
cached_result = cache.get(request.prompt)

Expand All @@ -104,18 +131,36 @@ def analyze_prompt(request: schemas.AnalyzePromptRequest) -> schemas.AnalyzeProm

return schemas.AnalyzePromptResponse(**cached_result)

sanitized_prompt, results_valid, results_score = scan_prompt(
input_scanners, request.prompt, config["scan_fail_fast"]
)
response = schemas.AnalyzePromptResponse(
sanitized_prompt=sanitized_prompt,
is_valid=all(results_valid.values()),
scanners=results_score,
)

cache.set(request.prompt, response.dict())

logger.debug(f"Sanitized response with the score: {results_score}")
with concurrent.futures.ThreadPoolExecutor() as executor:
loop = asyncio.get_event_loop()
try:
start_time = time.monotonic()
sanitized_prompt, results_valid, results_score = await asyncio.wait_for(
loop.run_in_executor(
executor,
scan_prompt,
input_scanners,
request.prompt,
config["scan_fail_fast"],
),
timeout=config["scan_prompt_timeout"],
)

response = schemas.AnalyzePromptResponse(
sanitized_prompt=sanitized_prompt,
is_valid=all(results_valid.values()),
scanners=results_score,
)
cache.set(request.prompt, response.dict())

elapsed_time = timedelta(seconds=time.monotonic() - start_time)
logger.debug(
f"Sanitized response with the score: {results_score}. Elapsed time: {elapsed_time}"
)
except asyncio.TimeoutError:
raise HTTPException(
status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout."
)

return response

Expand All @@ -127,16 +172,15 @@ def shutdown_event():
async def http_exception_handler(request, exc):
    """Log an HTTP exception and return it as a structured JSON error body.

    The body uses the keys ``message`` and ``details``; ``details`` is always
    ``None`` here. The response keeps the exception's own status code.
    """
    logger.warning(f"HTTP exception: {exc}. Request {request}")

    return JSONResponse(
        {"message": str(exc.detail), "details": None}, status_code=exc.status_code
    )

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Log a request-validation failure and return it as a 422 JSON body.

    The body has a fixed ``message`` and puts the raw ``exc.errors()`` list
    under ``details``, passed through ``jsonable_encoder`` before being sent.
    """
    logger.warning(f"Invalid request: {exc}. Request {request}")

    # Pass the error entries through untouched: joining ``error["loc"]`` with
    # "." raises TypeError when a location segment is an integer index.
    response = {"message": "Validation failed", "details": exc.errors()}
    return JSONResponse(
        jsonable_encoder(response), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
    )
Expand Down
6 changes: 6 additions & 0 deletions examples/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ def get_env_config() -> Dict:
"scan_fail_fast": os.environ.get(
"SCAN_FAIL_FAST", False
), # If true, will stop scanning after the first scanner fails. Default is false.
"scan_prompt_timeout": os.environ.get(
"SCAN_PROMPT_TIMEOUT", 10
), # Time in seconds after which a prompt scan will timeout. Default is 10 seconds.
"scan_output_timeout": os.environ.get(
"SCAN_OUTPUT_TIMEOUT", 30
), # Time in seconds after which an output scan will timeout. Default is 30 seconds.
"cache_ttl": os.environ.get(
"CACHE_TTL", 60 * 60
), # Time in seconds after which a cached item expires. Default is 1 hour.
Expand Down
1 change: 1 addition & 0 deletions examples/api/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
asyncio==3.4.3
fastapi==0.103.1
llm-guard==0.2.2
pydantic==1.10.12
Expand Down

0 comments on commit 2a29a86

Please sign in to comment.