Skip to content

Commit

Permalink
Use the BedrockChat LLM instead of AzureOpenAI
Browse files — browse the repository at this point in the history
  • Loading branch information
mbklein committed May 3, 2024
1 parent 8011979 commit 35be9dc
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 89 deletions.
52 changes: 15 additions & 37 deletions chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from setup import (
opensearch_client,
opensearch_vector_store,
openai_chat_client,
bedrock_chat_client,
)
from typing import List
from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
Expand Down Expand Up @@ -38,15 +38,13 @@ class EventConfig:

api_token: ApiToken = field(init=False)
attributes: List[str] = field(init=False)
azure_endpoint: str = field(init=False)
azure_resource_name: str = field(init=False)
debug_mode: bool = field(init=False)
deployment_name: str = field(init=False)
model_id: str = field(init=False)
document_prompt: PromptTemplate = field(init=False)
event: dict = field(default_factory=dict)
index_name: str = field(init=False)
is_logged_in: bool = field(init=False)
k: int = field(init=False)
openai_api_version: str = field(init=False)
payload: dict = field(default_factory=dict)
prompt_text: str = field(init=False)
prompt: PromptTemplate = field(init=False)
Expand All @@ -61,13 +59,11 @@ def __post_init__(self):
self.payload = json.loads(self.event.get("body", "{}"))
self.api_token = ApiToken(signed_token=self.payload.get("auth"))
self.attributes = self._get_attributes()
self.azure_endpoint = self._get_azure_endpoint()
self.azure_resource_name = self._get_azure_resource_name()
self.debug_mode = self._is_debug_mode_enabled()
self.deployment_name = self._get_deployment_name()
self.index_name = self._get_opensearch_index()
self.model_id = self._get_model_id()
self.is_logged_in = self.api_token.is_logged_in()
self.k = self._get_k()
self.openai_api_version = self._get_openai_api_version()
self.prompt_text = self._get_prompt_text()
self.request_context = self.event.get("requestContext", {})
self.question = self.payload.get("question")
Expand All @@ -88,42 +84,28 @@ def _get_payload_value_with_superuser_check(self, key, default):
def _get_attributes_function(self):
try:
opensearch = opensearch_client()
mapping = opensearch.indices.get_mapping(index="dc-v2-work")
mapping = opensearch.indices.get_mapping(index=self._get_opensearch_index())
return list(next(iter(mapping.values()))['mappings']['properties'].keys())
except StopIteration:
return []

def _get_attributes(self):
return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)

def _get_azure_endpoint(self):
default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
return self._get_payload_value_with_superuser_check("azure_endpoint", default)

def _get_azure_resource_name(self):
azure_resource_name = self._get_payload_value_with_superuser_check(
"azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
)
if not azure_resource_name:
raise EnvironmentError(
"Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
)
return azure_resource_name

def _get_deployment_name(self):
def _get_model_id(self):
return self._get_payload_value_with_superuser_check(
"deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
"model_id", os.getenv("AI_MODEL_ID")
)

def _get_k(self):
value = self._get_payload_value_with_superuser_check("k", K_VALUE)
return min(value, MAX_K)

def _get_openai_api_version(self):
def _get_opensearch_index(self):
return self._get_payload_value_with_superuser_check(
"openai_api_version", VERSION
"index", os.getenv("INDEX_NAME")
)

def _get_prompt_text(self):
return self._get_payload_value_with_superuser_check("prompt", prompt_template())

Expand All @@ -144,10 +126,8 @@ def debug_message(self):
"type": "debug",
"message": {
"attributes": self.attributes,
"azure_endpoint": self.azure_endpoint,
"deployment_name": self.deployment_name,
"model_id": self.model_id,
"k": self.k,
"openai_api_version": self.openai_api_version,
"prompt": self.prompt_text,
"question": self.question,
"ref": self.ref,
Expand All @@ -173,13 +153,11 @@ def setup_llm_request(self):
self._setup_chain()

def _setup_vector_store(self):
self.opensearch = opensearch_vector_store()
self.opensearch = opensearch_vector_store(index_name=self.index_name)

def _setup_chat_client(self):
self.client = openai_chat_client(
deployment_name=self.deployment_name,
openai_api_base=self.azure_endpoint,
openai_api_version=self.openai_api_version,
self.client = bedrock_chat_client(
model_id=self.model_id,
callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
streaming=True,
)
Expand Down
10 changes: 4 additions & 6 deletions chat/src/helpers/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from helpers.metrics import token_usage
from openai.error import InvalidRequestError

def base_response(config, response):
return {"answer": response["output_text"], "ref": config.ref}
Expand All @@ -9,11 +8,9 @@ def debug_response(config, response, original_question):
response_base = base_response(config, response)
debug_info = {
"attributes": config.attributes,
"azure_endpoint": config.azure_endpoint,
"deployment_name": config.deployment_name,
"model_id": config.model_id,
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"openai_api_version": config.openai_api_version,
"prompt": config.prompt_text,
"ref": config.ref,
"temperature": config.temperature,
Expand All @@ -35,7 +32,8 @@ def get_and_send_original_question(config, docs):
"question": config.question,
"source_documents": doc_response,
}
config.socket.send(original_question)
if (config.socket):
config.socket.send(original_question)
return original_question

def extract_prompt_value(v):
Expand Down Expand Up @@ -67,7 +65,7 @@ def prepare_response(config):
prepared_response = debug_response(config, response, original_question)
else:
prepared_response = base_response(config, response)
except InvalidRequestError as err:
except Exception as err:
prepared_response = {
"question": config.question,
"error": str(err),
Expand Down
2 changes: 1 addition & 1 deletion chat/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Runtime Dependencies
boto3~=1.34.13
langchain~=0.1.8
langchain
langchain-community
openai~=0.27.8
opensearch-py
Expand Down
13 changes: 5 additions & 8 deletions chat/src/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from langchain_community.chat_models import AzureChatOpenAI
from content_handler import ContentHandler
from langchain_community.chat_models import BedrockChat
from handlers.opensearch_neural_search import OpenSearchNeuralSearch
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
Expand All @@ -10,14 +11,10 @@ def prefix(value):
env_prefix = None if env_prefix == "" else env_prefix
return '-'.join(filter(None, [env_prefix, value]))

def openai_chat_client(**kwargs):
return AzureChatOpenAI(
openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
**kwargs,
)
def bedrock_chat_client(model_id=os.getenv("AI_MODEL_ID"), region_name=os.getenv("AWS_REGION"), **kwargs):
return BedrockChat(model_id=model_id, region_name=region_name, **kwargs)

def opensearch_client(region_name=os.getenv("AWS_REGION")):
print(region_name)
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())
endpoint = os.getenv("OPENSEARCH_ENDPOINT")
Expand All @@ -29,7 +26,7 @@ def opensearch_client(region_name=os.getenv("AWS_REGION")):
http_auth=awsauth,
)

def opensearch_vector_store(region_name=os.getenv("AWS_REGION")):
def opensearch_vector_store(region_name=os.getenv("AWS_REGION"), index_name="dc-v2-work"):
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())

Expand Down
22 changes: 10 additions & 12 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@ AWSTemplateFormatVersion: "2010-09-09"
Transform: AWS::Serverless-2016-10-31
Description: Websocket Chat API for dc-api-v2
Parameters:
AIModelId:
Type: String
Description: Amazon Bedrock Model
ApiTokenSecret:
Type: String
Description: Secret Key for Encrypting JWTs (must match IIIF server)
AzureOpenaiApiKey:
Type: String
Description: Azure OpenAI API Key
AzureOpenaiLlmDeploymentId:
Type: String
Description: Azure OpenAI LLM Deployment ID
AzureOpenaiResourceName:
Type: String
Description: Azure OpenAI Resource Name
EnvironmentPrefix:
Type: String
Description: Prefix for Index names
Expand Down Expand Up @@ -198,10 +192,8 @@ Resources:
Timeout: 300
Environment:
Variables:
AI_MODEL_ID: !Ref AIModelId
API_TOKEN_SECRET: !Ref ApiTokenSecret
AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey
AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
ENV_PREFIX: !Ref EnvironmentPrefix
OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId
Expand All @@ -218,6 +210,12 @@ Resources:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
- Statement:
- Effect: Allow
Action:
- 'bedrock:InvokeModel'
- 'bedrock:InvokeModelWithResponseStream'
Resource: arn:aws:bedrock:*::foundation-model/*
Metadata:
BuildMethod: nodejs18.x
Deployment:
Expand Down
28 changes: 3 additions & 25 deletions chat/test/test_event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,22 @@
from unittest import TestCase, mock


class TestEventConfigWithoutAzureResource(TestCase):
def test_requires_an_azure_resource(self):
with self.assertRaises(EnvironmentError):
EventConfig()


@mock.patch.dict(
os.environ,
{
"AZURE_OPENAI_RESOURCE_NAME": "test",
"AI_MODEL_ID": "test",
},
)
class TestEventConfig(TestCase):
def test_fetches_attributes_from_vector_database(self):
os.environ.pop("AZURE_OPENAI_RESOURCE_NAME", None)
with self.assertRaises(EnvironmentError):
EventConfig()

def test_defaults(self):
actual = EventConfig(event={"body": json.dumps({"attributes": ["title"]})})
expected_defaults = {"azure_endpoint": "https://test.openai.azure.com/"}
self.assertEqual(actual.azure_endpoint, expected_defaults["azure_endpoint"])

def test_attempt_override_without_superuser_status(self):
actual = EventConfig(
event={
"body": json.dumps(
{
"azure_resource_name": "new_name_for_test",
"attributes": ["title", "subject", "date_created"],
"index": "testIndex",
"k": 100,
"openai_api_version": "2024-01-01",
"model_id": "model_override",
"question": "test question",
"ref": "test ref",
"temperature": 0.9,
Expand All @@ -51,20 +34,15 @@ def test_attempt_override_without_superuser_status(self):
)
expected_output = {
"attributes": EventConfig.DEFAULT_ATTRIBUTES,
"azure_endpoint": "https://test.openai.azure.com/",
"model_id": "test",
"k": 5,
"openai_api_version": "2023-07-01-preview",
"question": "test question",
"ref": "test ref",
"temperature": 0.2,
"text_key": "id",
}
self.assertEqual(actual.azure_endpoint, expected_output["azure_endpoint"])
self.assertEqual(actual.attributes, expected_output["attributes"])
self.assertEqual(actual.k, expected_output["k"])
self.assertEqual(
actual.openai_api_version, expected_output["openai_api_version"]
)
self.assertEqual(actual.question, expected_output["question"])
self.assertEqual(actual.ref, expected_output["ref"])
self.assertEqual(actual.temperature, expected_output["temperature"])
Expand Down

0 comments on commit 35be9dc

Please sign in to comment.