From 2a29a86eafe34e1531903fe9285380c85f6a9d5b Mon Sep 17 00:00:00 2001 From: asofter Date: Sat, 23 Sep 2023 20:58:39 +0200 Subject: [PATCH] * API support of timeout * clarified order of execution --- CHANGELOG.md | 2 + examples/api/Makefile | 2 +- examples/api/README.md | 2 + examples/api/app.py | 104 ++++++++++++++++++++++++---------- examples/api/config.py | 6 ++ examples/api/requirements.txt | 1 + 6 files changed, 86 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1350f93e..c6e8b770 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Swagger documentation on the [API](https://llm-guard.com/usage/api/) documentation page - Added `fail_fast` flag to stop the execution after the first failure - Updated API and Playground to support `fail_fast` flag + - Clarified order of execution in the documentation +- Added timeout configuration for API example ### Removed - diff --git a/examples/api/Makefile b/examples/api/Makefile index 52282bb9..29f19586 100644 --- a/examples/api/Makefile +++ b/examples/api/Makefile @@ -19,7 +19,7 @@ install: python -m spacy download en_core_web_trf run: - @uvicorn app:app --reload + DEBUG=true uvicorn app:app --reload build-docker: @docker build -t $(DOCKER_IMAGE_NAME):$(VERSION) . diff --git a/examples/api/README.md b/examples/api/README.md index 0b563951..19eacfe3 100644 --- a/examples/api/README.md +++ b/examples/api/README.md @@ -38,6 +38,8 @@ make run-docker - `CACHE_MAX_SIZE` (int): Maximum number of items in the cache. Default is unlimited. - `CACHE_TTL` (int): Time in seconds after which a cached item expires. Default is 1 hour. - `SCAN_FAIL_FAST` (bool): Stop scanning after the first failed check. Default is `False`. +- `SCAN_PROMPT_TIMEOUT` (int): Time in seconds after which a prompt scan will time out. Default is 10 seconds. 
+- `SCAN_OUTPUT_TIMEOUT` (int): Time in seconds after which an output scan will time out. Default is 30 seconds. ### Scanners diff --git a/examples/api/app.py b/examples/api/app.py index 70327512..dc2b8c37 100644 --- a/examples/api/app.py +++ b/examples/api/app.py @@ -1,9 +1,13 @@ +import asyncio +import concurrent.futures import logging +import time +from datetime import timedelta import schemas from cache import InMemoryCache from config import get_env_config, load_scanners_from_config -from fastapi import FastAPI, status +from fastapi import FastAPI, HTTPException, status from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -77,24 +81,47 @@ async def liveliness(): return JSONResponse({"status": "ready"}) @app.post("/analyze/output", tags=["Analyze"]) - def analyze_output(request: schemas.AnalyzeOutputRequest) -> schemas.AnalyzeOutputResponse: + async def analyze_output( + request: schemas.AnalyzeOutputRequest, + ) -> schemas.AnalyzeOutputResponse: logger.debug(f"Received analyze request: {request}") - sanitized_output, results_valid, results_score = scan_output( - output_scanners, request.prompt, request.output, config["scan_fail_fast"] - ) - response = schemas.AnalyzeOutputResponse( - sanitized_output=sanitized_output, - is_valid=all(results_valid.values()), - scanners=results_score, - ) - - logger.debug(f"Sanitized response with the score: {results_score}") + with concurrent.futures.ThreadPoolExecutor() as executor: + loop = asyncio.get_event_loop() + try: + start_time = time.monotonic() + sanitized_output, results_valid, results_score = await asyncio.wait_for( + loop.run_in_executor( + executor, + scan_output, + output_scanners, + request.prompt, + request.output, + config["scan_fail_fast"], + ), + timeout=config["scan_output_timeout"], + ) + + response = schemas.AnalyzeOutputResponse( + sanitized_output=sanitized_output, + is_valid=all(results_valid.values()), 
+ scanners=results_score, + ) + elapsed_time = timedelta(seconds=time.monotonic() - start_time) + logger.debug( + f"Sanitized response with the score: {results_score}. Elapsed time: {elapsed_time}" + ) + except asyncio.TimeoutError: + raise HTTPException( + status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout." + ) return response @app.post("/analyze/prompt", tags=["Analyze"]) - def analyze_prompt(request: schemas.AnalyzePromptRequest) -> schemas.AnalyzePromptResponse: + async def analyze_prompt( + request: schemas.AnalyzePromptRequest, + ) -> schemas.AnalyzePromptResponse: logger.debug(f"Received analyze request: {request}") cached_result = cache.get(request.prompt) @@ -104,18 +131,36 @@ def analyze_prompt(request: schemas.AnalyzePromptRequest) -> schemas.AnalyzeProm return schemas.AnalyzePromptResponse(**cached_result) - sanitized_prompt, results_valid, results_score = scan_prompt( - input_scanners, request.prompt, config["scan_fail_fast"] - ) - response = schemas.AnalyzePromptResponse( - sanitized_prompt=sanitized_prompt, - is_valid=all(results_valid.values()), - scanners=results_score, - ) - - cache.set(request.prompt, response.dict()) - - logger.debug(f"Sanitized response with the score: {results_score}") + with concurrent.futures.ThreadPoolExecutor() as executor: + loop = asyncio.get_event_loop() + try: + start_time = time.monotonic() + sanitized_prompt, results_valid, results_score = await asyncio.wait_for( + loop.run_in_executor( + executor, + scan_prompt, + input_scanners, + request.prompt, + config["scan_fail_fast"], + ), + timeout=config["scan_prompt_timeout"], + ) + + response = schemas.AnalyzePromptResponse( + sanitized_prompt=sanitized_prompt, + is_valid=all(results_valid.values()), + scanners=results_score, + ) + cache.set(request.prompt, response.dict()) + + elapsed_time = timedelta(seconds=time.monotonic() - start_time) + logger.debug( + f"Sanitized response with the score: {results_score}. 
Elapsed time: {elapsed_time}" + ) + except asyncio.TimeoutError: + raise HTTPException( + status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout." + ) return response @@ -127,16 +172,15 @@ def shutdown_event(): async def http_exception_handler(request, exc): logger.warning(f"HTTP exception: {exc}. Request {request}") - return JSONResponse(str(exc.detail), status_code=exc.status_code) + return JSONResponse( + {"message": str(exc.detail), "details": None}, status_code=exc.status_code + ) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request, exc): logger.warning(f"Invalid request: {exc}. Request {request}") - response = {"message": "Validation failed", "details": []} - for error in exc.errors(): - response["details"].append(f"{'.'.join(error['loc'])}: {error['msg']}") - + response = {"message": "Validation failed", "details": exc.errors()} return JSONResponse( jsonable_encoder(response), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY ) diff --git a/examples/api/config.py b/examples/api/config.py index 7b30886a..22ead5f7 100644 --- a/examples/api/config.py +++ b/examples/api/config.py @@ -18,6 +18,12 @@ def get_env_config() -> Dict: "scan_fail_fast": os.environ.get( "SCAN_FAIL_FAST", False ), # If true, will stop scanning after the first scanner fails. Default is false. + "scan_prompt_timeout": os.environ.get( + "SCAN_PROMPT_TIMEOUT", 10 + ), # Time in seconds after which a prompt scan will time out. Default is 10 seconds. + "scan_output_timeout": os.environ.get( + "SCAN_OUTPUT_TIMEOUT", 30 + ), # Time in seconds after which an output scan will time out. Default is 30 seconds. "cache_ttl": os.environ.get( "CACHE_TTL", 60 * 60 ), # Time in seconds after which a cached item expires. Default is 1 hour. 
diff --git a/examples/api/requirements.txt b/examples/api/requirements.txt index f2ba1cfe..9883e9e0 100644 --- a/examples/api/requirements.txt +++ b/examples/api/requirements.txt @@ -1,3 +1,4 @@ +asyncio==3.4.3 fastapi==0.103.1 llm-guard==0.2.2 pydantic==1.10.12