Skip to content

Commit

Permalink
Update chat handler for OpenSearch
Browse files Browse the repository at this point in the history
Update permissions on chat websocket function

Add AWS4Auth to opensearch client

Tweak EventConfig to make chat work with OpenSearch
  • Loading branch information
kdid authored and mbklein committed Feb 26, 2024
1 parent 0f2f773 commit 8579d09
Show file tree
Hide file tree
Showing 16 changed files with 194 additions and 143 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ $RECYCLE.BIN/
/docs/docs/spec/openapi.json
/docs/site

.venv

.vscode
/samconfig.toml
/samconfig.yaml
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ cover-html-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
style-python: deps-python
cd chat && ruff check .
style-python-fix: deps-python
cd chat && ruff check --fix .
test-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && PYTHONPATH=src:test && python -m unittest discover -v
python-version:
Expand Down
7 changes: 5 additions & 2 deletions chat/dependencies/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
boto3~=1.34.13
langchain~=0.0.208
langchain~=0.1.8
langchain-community
openai~=0.27.8
opensearch-py
pyjwt~=2.6.0
python-dotenv~=1.0.0
requests
requests-aws4auth
tiktoken~=0.4.0
weaviate-client~=3.19.2
wheel~=0.40.0
36 changes: 36 additions & 0 deletions chat/src/content_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json
from typing import Dict, List
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

class ContentHandler(EmbeddingsContentHandler):
    """Serialize/deserialize embedding payloads for a SageMaker endpoint.

    Requests and responses are both JSON: inputs go out under an "inputs"
    key, and the endpoint's reply is expected to carry a single vector
    under an "embedding" key.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        """Encode the input strings as UTF-8 JSON bytes for the endpoint.

        Args:
            inputs: List of input strings.
            model_kwargs: Extra keyword arguments forwarded to the endpoint.

        Returns:
            UTF-8 encoded JSON request body.
        """
        payload = {"inputs": inputs, **model_kwargs}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """Decode the endpoint response into a list of embedding vectors.

        Args:
            output: Response body from the SageMaker endpoint.

        Returns:
            A one-element outer list wrapping the embedding vector found
            under the response's "embedding" key.

        Note:
            Despite the ``bytes`` annotation, ``output`` is read via
            ``.read()`` — it behaves like a file-like streaming body
            (TODO confirm with the SageMaker client caller).
        """
        body = output.read().decode("utf-8")
        parsed = json.loads(body)
        # Wrap the single vector so the outer list length matches the
        # number of input strings handled per call.
        return [parsed["embedding"]]
91 changes: 42 additions & 49 deletions chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from setup import (
weaviate_client,
weaviate_vector_store,
opensearch_client,
opensearch_vector_store,
openai_chat_client,
)
from typing import List
Expand All @@ -15,24 +15,27 @@
from helpers.prompts import document_template, prompt_template
from websocket import Websocket


CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
INDEX_NAME = "DCWork"
K_VALUE = 10
K_VALUE = 5
MAX_K = 100
TEMPERATURE = 0.2
TEXT_KEY = "title"
VERSION = "2023-07-01-preview"


@dataclass
class EventConfig:
"""
The EventConfig class represents the configuration for an event.
Default values are set for the following properties which can be overridden in the payload message.
"""

DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection",
"contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier",
"keywords", "language", "notes", "physical_description_material", "physical_description_size",
"provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail",
"title", "visibility", "work_type"]

api_token: ApiToken = field(init=False)
attributes: List[str] = field(init=False)
azure_endpoint: str = field(init=False)
Expand All @@ -41,7 +44,6 @@ class EventConfig:
deployment_name: str = field(init=False)
document_prompt: PromptTemplate = field(init=False)
event: dict = field(default_factory=dict)
index_name: str = field(init=False)
is_logged_in: bool = field(init=False)
k: int = field(init=False)
openai_api_version: str = field(init=False)
Expand All @@ -54,7 +56,7 @@ class EventConfig:
temperature: float = field(init=False)
socket: Websocket = field(init=False, default=None)
text_key: str = field(init=False)

def __post_init__(self):
self.payload = json.loads(self.event.get("body", "{}"))
self.api_token = ApiToken(signed_token=self.payload.get("auth"))
Expand All @@ -64,7 +66,6 @@ def __post_init__(self):
self.azure_endpoint = self._get_azure_endpoint()
self.debug_mode = self._is_debug_mode_enabled()
self.deployment_name = self._get_deployment_name()
self.index_name = self._get_index_name()
self.is_logged_in = self.api_token.is_logged_in()
self.k = self._get_k()
self.openai_api_version = self._get_openai_api_version()
Expand All @@ -74,75 +75,70 @@ def __post_init__(self):
self.ref = self.payload.get("ref")
self.temperature = self._get_temperature()
self.text_key = self._get_text_key()
self.attributes = self._get_attributes()
self.document_prompt = self._get_document_prompt()
self.prompt = PromptTemplate(template=self.prompt_text, input_variables=["question", "context"])
self.prompt = PromptTemplate(
template=self.prompt_text, input_variables=["question", "context"]
)

def _get_payload_value_with_superuser_check(self, key, default):
if self.api_token.is_superuser():
return self.payload.get(key, default)
else:
return default

def _get_attributes_function(self, index="dc-v2-work"):
    """Discover attribute names from an OpenSearch index mapping.

    Args:
        index: OpenSearch index to inspect (defaults to the works index,
            preserving the previous hard-coded behavior).

    Returns:
        The property names of the first mapping in the response, or an
        empty list when the mapping response is empty.
    """
    try:
        opensearch = opensearch_client()
        mapping = opensearch.indices.get_mapping(index=index)
        # The response dict is keyed by the concrete index name; inspect
        # the first (and normally only) entry.
        return list(next(iter(mapping.values()))["mappings"]["properties"].keys())
    except StopIteration:
        # Empty mapping response — nothing to discover.
        return []

def _get_attributes(self):
return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)
# return self._get_payload_value_with_superuser_check("attributes", self._get_attributes_function())

def _get_azure_endpoint(self):
default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
return self._get_payload_value_with_superuser_check("azure_endpoint", default)

def _get_azure_resource_name(self):
azure_resource_name = self._get_payload_value_with_superuser_check("azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME"))
azure_resource_name = self._get_payload_value_with_superuser_check(
"azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
)
if not azure_resource_name:
raise EnvironmentError(
"Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
)
return azure_resource_name

def _get_deployment_name(self):
return self._get_payload_value_with_superuser_check("deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))

def _get_index_name(self):
return self._get_payload_value_with_superuser_check("index", INDEX_NAME)
return self._get_payload_value_with_superuser_check(
"deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
)

def _get_k(self):
    """Return the requested similarity-search result count, capped at MAX_K."""
    return min(self._get_payload_value_with_superuser_check("k", K_VALUE), MAX_K)

def _get_openai_api_version(self):
    """Return the OpenAI API version string (payload-overridable).

    An unreachable duplicate return left over from a merge has been
    removed; behavior is unchanged.
    """
    return self._get_payload_value_with_superuser_check("openai_api_version", VERSION)

def _get_prompt_text(self):
    """Return the chat prompt template text (payload-overridable)."""
    default_prompt = prompt_template()
    return self._get_payload_value_with_superuser_check("prompt", default_prompt)

def _get_temperature(self):
    """Return the LLM sampling temperature (payload-overridable)."""
    return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

def _get_text_key(self):
    """Return the document field used as page content (payload-overridable)."""
    return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

def _get_attributes(self):
attributes = [
item
for item in self._get_request_attributes()
if item not in [self._get_text_key(), "source", "full_text"]
]
return attributes

def _get_request_attributes(self):
if os.getenv("SKIP_WEAVIATE_SETUP"):
return []

attributes = self._get_payload_value_with_superuser_check("attributes", [])
if attributes:
return attributes
else:
client = weaviate_client()
schema = client.schema.get(self._get_index_name())
names = [prop["name"] for prop in schema.get("properties")]
return names

def _get_document_prompt(self):
    """Build the per-document PromptTemplate for rendering source docs.

    Uses "title" as the content field and "id" as the source field to
    match the OpenSearch document schema. The merge residue that passed
    ``input_variables`` twice (a SyntaxError) has been resolved in favor
    of the new ["title", "id"] form.
    """
    return PromptTemplate(
        template=document_template(self.attributes),
        input_variables=["title", "id"] + self.attributes,
    )

def debug_message(self):
Expand All @@ -152,7 +148,6 @@ def debug_message(self):
"attributes": self.attributes,
"azure_endpoint": self.azure_endpoint,
"deployment_name": self.deployment_name,
"index": self.index_name,
"k": self.k,
"openai_api_version": self.openai_api_version,
"prompt": self.prompt_text,
Expand All @@ -167,7 +162,9 @@ def setup_websocket(self, socket=None):
if socket is None:
connection_id = self.request_context.get("connectionId")
endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
self.socket = Websocket(endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref)
self.socket = Websocket(
endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref
)
else:
self.socket = socket
return self.socket
Expand All @@ -178,11 +175,7 @@ def setup_llm_request(self):
self._setup_chain()

def _setup_vector_store(self):
    """Attach the OpenSearch-backed vector store used for similarity search.

    The leftover Weaviate wiring (weaviate_vector_store with index_name /
    text_key / attributes) referenced names removed in the OpenSearch
    migration and would raise at runtime, so only the OpenSearch store is
    created.
    """
    self.opensearch = opensearch_vector_store()

def _setup_chat_client(self):
self.client = openai_chat_client(
Expand Down
13 changes: 7 additions & 6 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import sys
import traceback
from event_config import EventConfig
from helpers.response import prepare_response

Expand All @@ -21,9 +23,8 @@ def handler(event, _context):
config.socket.send(final_response)
return {"statusCode": 200}

except Exception as err:
if err.__class__.__name__ == "PayloadTooLargeException":
config.socket.send({"type": "error", "message": "Payload too large"})
return {"statusCode": 413, "body": "Payload too large"}
else:
raise err
except Exception:
exc_info = sys.exc_info()
err_text = ''.join(traceback.format_exception(*exc_info))
print(err_text)
return {"statusCode": 500, "body": f'Unhandled error:\n{err_text}'}
4 changes: 2 additions & 2 deletions chat/src/helpers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def document_template(attributes: Optional[List[str]] = None) -> str:
if attributes is None:
attributes = []
lines = (
["Content: {page_content}", "Metadata:"]
["Content: {title}", "Metadata:"]
+ [f" {attribute}: {{{attribute}}}" for attribute in attributes]
+ ["Source: {source}"]
+ ["Source: {id}"]
)
return "\n".join(lines)
23 changes: 17 additions & 6 deletions chat/src/helpers/response.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from helpers.metrics import token_usage
from openai.error import InvalidRequestError


def base_response(config, response):
    """Shape the chain output into the minimal websocket reply payload."""
    answer = response["output_text"]
    return {"answer": answer, "ref": config.ref}

Expand All @@ -12,7 +11,6 @@ def debug_response(config, response, original_question):
"attributes": config.attributes,
"azure_endpoint": config.azure_endpoint,
"deployment_name": config.deployment_name,
"index": config.index_name,
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"openai_api_version": config.openai_api_version,
Expand All @@ -26,19 +24,32 @@ def debug_response(config, response, original_question):


def get_and_send_original_question(config, docs):
    """Send the user's question plus trimmed source documents over the socket.

    Each document's metadata is reduced to the attributes listed in
    config.attributes, with values normalized by extract_prompt_value.
    Returns the payload that was sent.
    """
    source_documents = []
    for document in docs:
        metadata = vars(document).get("metadata", {})
        trimmed = {
            attr: extract_prompt_value(metadata[attr])
            for attr in config.attributes
            if attr in metadata
        }
        source_documents.append(trimmed)

    payload = {"question": config.question, "source_documents": source_documents}
    config.socket.send(payload)
    return payload


def extract_prompt_value(v):
    """Normalize a metadata value for display.

    Lists are normalized element-wise; a dict carrying a "label" key is
    collapsed to a one-element list holding that label; anything else is
    returned unchanged.
    """
    if isinstance(v, list):
        return [extract_prompt_value(item) for item in v]
    if isinstance(v, dict) and "label" in v:
        return [v["label"]]
    return v

def prepare_response(config):
try:
docs = config.weaviate.similarity_search(
config.question, k=config.k, additional="certainty"
docs = config.opensearch.similarity_search(
config.question, k=config.k, vector_field="embedding", text_field="id"
)
original_question = get_and_send_original_question(config, docs)
response = config.chain({"question": config.question, "input_documents": docs})
Expand Down
7 changes: 5 additions & 2 deletions chat/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Runtime Dependencies
boto3~=1.34.13
langchain~=0.0.208
langchain~=0.1.8
langchain-community
openai~=0.27.8
opensearch-py
pyjwt~=2.6.0
python-dotenv~=1.0.0
requests
requests-aws4auth
tiktoken~=0.4.0
weaviate-client~=3.19.2
wheel~=0.40.0

# Dev/Test Dependencies
Expand Down
Loading

0 comments on commit 8579d09

Please sign in to comment.