Skip to content

Commit

Permalink
Merge pull request #249 from nulib/4774-red-team-test
Browse files Browse the repository at this point in the history
Create http endpoint for chat for testing
  • Loading branch information
kdid committed Aug 26, 2024
2 parents aeeeb51 + 5017697 commit 5597f6a
Show file tree
Hide file tree
Showing 6 changed files with 463 additions and 0 deletions.
43 changes: 43 additions & 0 deletions chat/src/handlers/chat_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

import json
import logging
import os
from http_event_config import HTTPEventConfig
from helpers.http_response import HTTPResponse
from honeybadger import honeybadger

honeybadger.configure()
logging.getLogger('honeybadger').addHandler(logging.StreamHandler())

RESPONSE_TYPES = {
"base": ["answer", "ref"],
"debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
"log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"],
"error": ["question", "error", "source_documents"]
}

def handler(event, context):
print(f'Event: {event}')

config = HTTPEventConfig(event)

if not config.is_logged_in:
return {"statusCode": 401, "body": "Unauthorized"}

if config.question is None or config.question == "":
return {"statusCode": 400, "body": "Question cannot be blank"}

if not os.getenv("SKIP_WEAVIATE_SETUP"):
config.setup_llm_request()
response = HTTPResponse(config)
final_response = response.prepare_response()
if "error" in final_response:
logging.error(f'Error: {final_response["error"]}')
return {"statusCode": 500, "body": "Internal Server Error"}
else:
return {"statusCode": 200, "body": json.dumps(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))}

return {"statusCode": 200}

def reshape_response(response, type):
return {k: response[k] for k in RESPONSE_TYPES[type]}
60 changes: 60 additions & 0 deletions chat/src/helpers/http_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from helpers.metrics import debug_response
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

def extract_prompt_value(v):
if isinstance(v, list):
return [extract_prompt_value(item) for item in v]
elif isinstance(v, dict) and 'label' in v:
return [v.get('label')]
else:
return v

class HTTPResponse:
def __init__(self, config):
self.config = config
self.store = {}

def debug_response_passthrough(self):
return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question))

def original_question_passthrough(self):
def get_and_send_original_question(docs):
source_documents = []
for doc in docs["context"]:
doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata}
source_document = doc.metadata.copy()
source_document["content"] = doc.page_content
source_documents.append(source_document)

original_question = {
"question": self.config.question,
"source_documents": source_documents,
}

self.original_question = original_question
return docs

return RunnablePassthrough(get_and_send_original_question)

def prepare_response(self):
try:
retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}})
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| self.original_question_passthrough()
| self.config.prompt
| self.config.client
| StrOutputParser()
| self.debug_response_passthrough()
)
response = chain.invoke(self.config.question)
except Exception as err:
response = {
"question": self.config.question,
"error": str(err),
"source_documents": [],
}
return response


188 changes: 188 additions & 0 deletions chat/src/http_event_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import os
import json

from dataclasses import dataclass, field

from langchain_core.prompts import ChatPromptTemplate
from setup import (
opensearch_client,
opensearch_vector_store,
openai_chat_client,
)
from typing import List
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template

CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
K_VALUE = 40
MAX_K = 100
MAX_TOKENS = 1000
SIZE = 5
TEMPERATURE = 0.2
TEXT_KEY = "id"
VERSION = "2024-02-01"

@dataclass
class HTTPEventConfig:
"""
The EventConfig class represents the configuration for an event.
Default values are set for the following properties which can be overridden in the payload message.
"""

DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection",
"contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier",
"keywords", "language", "notes", "physical_description_material", "physical_description_size",
"provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail",
"title", "visibility", "work_type"]

api_token: ApiToken = field(init=False)
attributes: List[str] = field(init=False)
azure_endpoint: str = field(init=False)
azure_resource_name: str = field(init=False)
debug_mode: bool = field(init=False)
deployment_name: str = field(init=False)
document_prompt: ChatPromptTemplate = field(init=False)
event: dict = field(default_factory=dict)
is_logged_in: bool = field(init=False)
k: int = field(init=False)
max_tokens: int = field(init=False)
openai_api_version: str = field(init=False)
payload: dict = field(default_factory=dict)
prompt_text: str = field(init=False)
prompt: ChatPromptTemplate = field(init=False)
question: str = field(init=False)
ref: str = field(init=False)
request_context: dict = field(init=False)
temperature: float = field(init=False)
size: int = field(init=False)
stream_response: bool = field(init=False)
text_key: str = field(init=False)

def __post_init__(self):
self.payload = json.loads(self.event.get("body", "{}"))
self.api_token = ApiToken(signed_token=self.payload.get("auth"))
self.attributes = self._get_attributes()
self.azure_endpoint = self._get_azure_endpoint()
self.azure_resource_name = self._get_azure_resource_name()
self.debug_mode = self._is_debug_mode_enabled()
self.deployment_name = self._get_deployment_name()
self.is_logged_in = self.api_token.is_logged_in()
self.k = self._get_k()
self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS)
self.openai_api_version = self._get_openai_api_version()
self.prompt_text = self._get_prompt_text()
self.request_context = self.event.get("requestContext", {})
self.question = self.payload.get("question")
self.ref = self.payload.get("ref")
self.size = self._get_size()
self.stream_response = self.payload.get("stream_response", not self.debug_mode)
self.temperature = self._get_temperature()
self.text_key = self._get_text_key()
self.document_prompt = self._get_document_prompt()
self.prompt = ChatPromptTemplate.from_template(self.prompt_text)

def _get_payload_value_with_superuser_check(self, key, default):
if self.api_token.is_superuser():
return self.payload.get(key, default)
else:
return default

def _get_attributes_function(self):
try:
opensearch = opensearch_client()
mapping = opensearch.indices.get_mapping(index="dc-v2-work")
return list(next(iter(mapping.values()))['mappings']['properties'].keys())
except StopIteration:
return []

def _get_attributes(self):
return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)

def _get_azure_endpoint(self):
default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
return self._get_payload_value_with_superuser_check("azure_endpoint", default)

def _get_azure_resource_name(self):
azure_resource_name = self._get_payload_value_with_superuser_check(
"azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
)
if not azure_resource_name:
raise EnvironmentError(
"Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
)
return azure_resource_name

def _get_deployment_name(self):
return self._get_payload_value_with_superuser_check(
"deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
)

def _get_k(self):
value = self._get_payload_value_with_superuser_check("k", K_VALUE)
return min(value, MAX_K)

def _get_openai_api_version(self):
return self._get_payload_value_with_superuser_check(
"openai_api_version", VERSION
)

def _get_prompt_text(self):
return self._get_payload_value_with_superuser_check("prompt", prompt_template())

def _get_size(self):
return self._get_payload_value_with_superuser_check("size", SIZE)

def _get_temperature(self):
return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

def _get_text_key(self):
return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

def _get_document_prompt(self):
return ChatPromptTemplate.from_template(document_template(self.attributes))

def debug_message(self):
return {
"type": "debug",
"message": {
"attributes": self.attributes,
"azure_endpoint": self.azure_endpoint,
"deployment_name": self.deployment_name,
"k": self.k,
"openai_api_version": self.openai_api_version,
"prompt": self.prompt_text,
"question": self.question,
"ref": self.ref,
"size": self.ref,
"temperature": self.temperature,
"text_key": self.text_key,
},
}

def setup_llm_request(self):
self._setup_vector_store()
self._setup_chat_client()

def _setup_vector_store(self):
self.opensearch = opensearch_vector_store()

def _setup_chat_client(self):
self.client = openai_chat_client(
azure_deployment=self.deployment_name,
azure_endpoint=self.azure_endpoint,
openai_api_version=self.openai_api_version,
max_tokens=self.max_tokens
)

def _is_debug_mode_enabled(self):
debug = self.payload.get("debug", False)
return debug and self.api_token.is_superuser()

def _to_bool(self, val):
"""Converts a value to boolean. If the value is a string, it considers
"", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
"""
if isinstance(val, str):
return val.lower() not in ["", "no", "false", "0"]
return bool(val)
42 changes: 42 additions & 0 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,48 @@ Resources:
Resource: !Sub "${ChatMetricsLog.Arn}:*"
#* Metadata:
#* BuildMethod: nodejs20.x
ChatSyncFunction:
Type: AWS::Serverless::Function
Properties:
CodeUri: ./src
Runtime: python3.10
Architectures:
- x86_64
#* Layers:
#* - !Ref ChatDependencies
MemorySize: 1024
Handler: handlers/chat_sync.handler
Timeout: 300
Environment:
Variables:
API_TOKEN_SECRET: !Ref ApiTokenSecret
AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey
AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
ENV_PREFIX: !Ref EnvironmentPrefix
HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv
HONEYBADGER_REVISION: !Ref HoneybadgerRevision
METRICS_LOG_GROUP: !Ref ChatMetricsLog
OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId
FunctionUrlConfig:
AuthType: NONE
Policies:
- Statement:
- Effect: Allow
Action:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
# - Statement:
# - Effect: Allow
# Action:
# - logs:CreateLogStream
# - logs:PutLogEvents
# Resource: !Sub "${ChatMetricsLog.Arn}:*"
#* Metadata:
#* BuildMethod: nodejs20.x
ChatMetricsLog:
Type: AWS::Logs::LogGroup
Properties:
Expand Down
35 changes: 35 additions & 0 deletions chat/test/handlers/test_chat_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# ruff: noqa: E402

import os
import sys

sys.path.append('./src')

from unittest import mock, TestCase
from unittest.mock import patch
from handlers.chat_sync import handler
from helpers.apitoken import ApiToken

class MockContext:
def __init__(self):
self.log_stream_name = 'test'

@mock.patch.dict(
os.environ,
{
"AZURE_OPENAI_RESOURCE_NAME": "test",
},
)
class TestHandler(TestCase):
def test_handler_unauthorized(self):
self.assertEqual(handler({"body": '{ "question": "Question?"}'}, MockContext()), {'body': 'Unauthorized', 'statusCode': 401})

@patch.object(ApiToken, 'is_logged_in')
def test_no_question(self, mock_is_logged_in):
mock_is_logged_in.return_value = True
self.assertEqual(handler({"body": '{ "question": ""}'}, MockContext()), {'statusCode': 400, 'body': 'Question cannot be blank'})

@patch.object(ApiToken, 'is_logged_in')
def test_handler_success(self, mock_is_logged_in):
mock_is_logged_in.return_value = True
self.assertEqual(handler({"body": '{"question": "Question?"}'}, MockContext()), {'statusCode': 200})
Loading

0 comments on commit 5597f6a

Please sign in to comment.