Skip to content

Commit

Permalink
Implement mock response for chat tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kdid authored and mbklein committed Aug 29, 2024
1 parent cd7690b commit 9eafb82
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
env:
AWS_ACCESS_KEY_ID: ci
AWS_SECRET_ACCESS_KEY: ci
SKIP_WEAVIATE_SETUP: 'True'
SKIP_LLM_REQUEST: 'True'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ test-node: deps-node
deps-python:
cd chat/src && pip install -r requirements.txt && pip install -r requirements-dev.txt
cover-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
cd chat && export SKIP_LLM_REQUEST=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
cover-html-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
cd chat && export SKIP_LLM_REQUEST=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
style-python: deps-python
cd chat && ruff check .
style-python-fix: deps-python
cd chat && ruff check --fix .
test-python: deps-python
cd chat && SKIP_WEAVIATE_SETUP=True PYTHONPATH=src:test python -m unittest discover -v
cd chat && SKIP_LLM_REQUEST=True PYTHONPATH=src:test python -m unittest discover -v
python-version:
cd chat && python --version
build: .aws-sam/build.toml
Expand Down
22 changes: 11 additions & 11 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

RESPONSE_TYPES = {
"base": ["answer", "ref"],
"debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_dev_team", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
"debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
"log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts", "is_dev_team"],
"error": ["question", "error", "source_documents"]
}
Expand All @@ -31,20 +31,20 @@ def handler(event, context):
return {"statusCode": 400, "body": "Question cannot be blank"}

debug_message = config.debug_message()
print(f"Debug mode: {config.debug_mode}")
if config.debug_mode:
config.socket.send(debug_message)

if not os.getenv("SKIP_WEAVIATE_SETUP"):
if not os.getenv("SKIP_LLM_REQUEST"):
config.setup_llm_request()
response = Response(config)
final_response = response.prepare_response()
if "error" in final_response:
logging.error(f'Error: {final_response["error"]}')
config.socket.send({"type": "error", "message": "Internal Server Error"})
return {"statusCode": 500, "body": "Internal Server Error"}
else:
config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))

response = Response(config)
final_response = response.prepare_response()
if "error" in final_response:
logging.error(f'Error: {final_response["error"]}')
config.socket.send({"type": "error", "message": "Internal Server Error"})
return {"statusCode": 500, "body": "Internal Server Error"}
else:
config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))

log_group = os.getenv('METRICS_LOG_GROUP')
log_stream = context.log_stream_name
Expand Down
2 changes: 1 addition & 1 deletion chat/src/handlers/chat_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def handler(event, context):
if config.question is None or config.question == "":
return {"statusCode": 400, "body": "Question cannot be blank"}

if not os.getenv("SKIP_WEAVIATE_SETUP"):
if not os.getenv("SKIP_LLM_REQUEST"):
config.setup_llm_request()
response = HTTPResponse(config)
final_response = response.prepare_response()
Expand Down
1 change: 1 addition & 0 deletions chat/src/helpers/apitoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, signed_token=None):
try:
secret = os.getenv("API_TOKEN_SECRET")
self.token = jwt.decode(signed_token, secret, algorithms=["HS256"])
print(self.token)
except Exception:
self.token = ApiToken.empty_token()

Expand Down
41 changes: 34 additions & 7 deletions chat/test/handlers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from unittest.mock import patch
from handlers.chat import handler
from helpers.apitoken import ApiToken
from helpers.response import Response
from websocket import Websocket
from event_config import EventConfig

Expand All @@ -25,12 +26,36 @@ class MockContext:
def __init__(self):
self.log_stream_name = 'test'

# TODO: Find a way to build a better mock response (maybe using helpers.metrics.debug_response)
def mock_response(**kwargs):
result = {
'answer': 'Answer.',
'attributes': ['accession_number', 'alternate_title', 'api_link', 'canonical_link', 'caption', 'collection', 'contributor', 'date_created', 'date_created_edtf', 'description', 'genre', 'id', 'identifier', 'keywords', 'language', 'notes', 'physical_description_material', 'physical_description_size', 'provenance', 'publisher', 'rights_statement', 'subject', 'table_of_contents', 'thumbnail', 'title', 'visibility', 'work_type'],
'azure_endpoint': 'https://nul-ai-east.openai.azure.com/',
'deployment_name': 'gpt-4o',
'is_dev_team': False,
'is_superuser': False,
'k': 10,
'openai_api_version': '2024-02-01',
'prompt': "Prompt",
'question': 'Question?',
'ref': 'ref123',
'size': 20,
'source_documents': [],
'temperature': 0.2,
'text_key': 'id',
'token_counts': {'question': 19, 'answer': 348, 'prompt': 329, 'source_documents': 10428,'total': 11124}
}
result.update(kwargs)
return result

@mock.patch.dict(
os.environ,
{
"AZURE_OPENAI_RESOURCE_NAME": "test",
},
)
@mock.patch.object(Response, "prepare_response", lambda _: mock_response())
class TestHandler(TestCase):
def test_handler_unauthorized(self):
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
Expand All @@ -45,7 +70,7 @@ def test_handler_success(self, mock_is_logged_in):
@patch.object(ApiToken, 'is_logged_in')
@patch.object(ApiToken, 'is_superuser')
@patch.object(EventConfig, '_is_debug_mode_enabled')
def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_logged_in, mock_is_superuser):
def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_superuser, mock_is_logged_in):
mock_is_debug_enabled.return_value = True
mock_is_logged_in.return_value = True
mock_is_superuser.return_value = True
Expand All @@ -54,21 +79,23 @@ def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_logged_in, mock
event = {"socket": mock_websocket, "debug": True, "body": '{"question": "Question?"}' }
handler(event, MockContext())
response = json.loads(mock_client.received_data)
self.assertEqual(response["type"], "debug")
expected_keys = {"attributes", "azure_endpoint", "deployment_name"}
received_keys = response.keys()
self.assertTrue(expected_keys.issubset(received_keys))

@patch.object(ApiToken, 'is_logged_in')
@patch.object(ApiToken, 'is_superuser')
@patch.object(EventConfig, '_is_debug_mode_enabled')
def test_handler_debug_mode_for_superusers_only(self, mock_is_debug_enabled, mock_is_logged_in, mock_is_superuser):
mock_is_debug_enabled.return_value = True
def test_handler_debug_mode_for_superusers_only(self, mock_is_superuser, mock_is_logged_in):
mock_is_logged_in.return_value = True
mock_is_superuser.return_value = False
mock_client = MockClient()
mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
event = {"socket": mock_websocket, "debug": True, "body": '{"question": "Question?"}' }
event = {"socket": mock_websocket, "body": '{"question": "Question?", "debug": "true"}'}
handler(event, MockContext())
response = json.loads(mock_client.received_data)
self.assertEqual(response["type"], "error")
expected_keys = {"answer", "ref"}
received_keys = set(response.keys())
self.assertSetEqual(received_keys, expected_keys)

@patch.object(ApiToken, 'is_logged_in')
def test_handler_question_missing(self, mock_is_logged_in):
Expand Down

0 comments on commit 9eafb82

Please sign in to comment.