diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml
index ed9650cc..7b76990a 100644
--- a/.github/workflows/test-python.yml
+++ b/.github/workflows/test-python.yml
@@ -14,7 +14,7 @@ jobs:
     env:
       AWS_ACCESS_KEY_ID: ci
       AWS_SECRET_ACCESS_KEY: ci
-      SKIP_WEAVIATE_SETUP: 'True'
+      SKIP_LLM_REQUEST: 'True'
     steps:
       - uses: actions/checkout@v3
       - uses: actions/setup-python@v4
diff --git a/Makefile b/Makefile
index e2d9e6a4..4864e7ab 100644
--- a/Makefile
+++ b/Makefile
@@ -52,15 +52,15 @@ test-node: deps-node
 deps-python:
 	cd chat/src && pip install -r requirements.txt && pip install -r requirements-dev.txt
 cover-python: deps-python
-	cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
+	cd chat && export SKIP_LLM_REQUEST=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
 cover-html-python: deps-python
-	cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
+	cd chat && export SKIP_LLM_REQUEST=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
 style-python: deps-python
 	cd chat && ruff check .
 style-python-fix: deps-python
 	cd chat && ruff check --fix .
 test-python: deps-python
-	cd chat && SKIP_WEAVIATE_SETUP=True PYTHONPATH=src:test python -m unittest discover -v
+	cd chat && SKIP_LLM_REQUEST=True PYTHONPATH=src:test python -m unittest discover -v
 python-version:
 	cd chat && python --version
 build: .aws-sam/build.toml
diff --git a/chat/src/event_config.py b/chat/src/event_config.py
index 3b1e8ae7..28e09348 100644
--- a/chat/src/event_config.py
+++ b/chat/src/event_config.py
@@ -46,7 +46,9 @@ class EventConfig:
     deployment_name: str = field(init=False)
     document_prompt: ChatPromptTemplate = field(init=False)
     event: dict = field(default_factory=dict)
+    is_dev_team: bool = field(init=False)
     is_logged_in: bool = field(init=False)
+    is_superuser: bool = field(init=False)
     k: int = field(init=False)
     max_tokens: int = field(init=False)
     openai_api_version: str = field(init=False)
@@ -70,7 +72,9 @@ def __post_init__(self):
         self.azure_resource_name = self._get_azure_resource_name()
         self.debug_mode = self._is_debug_mode_enabled()
         self.deployment_name = self._get_deployment_name()
+        self.is_dev_team = self.api_token.is_dev_team()
         self.is_logged_in = self.api_token.is_logged_in()
+        self.is_superuser = self.api_token.is_superuser()
         self.k = self._get_k()
         self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS)
         self.openai_api_version = self._get_openai_api_version()
diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py
index 3630bfab..443a4f29 100644
--- a/chat/src/handlers/chat.py
+++ b/chat/src/handlers/chat.py
@@ -13,7 +13,7 @@
 RESPONSE_TYPES = {
     "base": ["answer", "ref"],
     "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
-    "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"],
+    "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts", "is_dev_team"],
     "error": ["question", "error", "source_documents"]
 }
 
@@ -22,7 +22,7 @@ def handler(event, context):
     socket = event.get('socket', None)
     config.setup_websocket(socket)
 
-    if not config.is_logged_in:
+    if not (config.is_logged_in or config.is_superuser):
         config.socket.send({"type": "error", "message": "Unauthorized"})
         return {"statusCode": 401, "body": "Unauthorized"}
 
@@ -34,16 +34,17 @@ def handler(event, context):
     if config.debug_mode:
         config.socket.send(debug_message)
 
-    if not os.getenv("SKIP_WEAVIATE_SETUP"):
+    if not os.getenv("SKIP_LLM_REQUEST"):
         config.setup_llm_request()
-        response = Response(config)
-        final_response = response.prepare_response()
-        if "error" in final_response:
-            logging.error(f'Error: {final_response["error"]}')
-            config.socket.send({"type": "error", "message": "Internal Server Error"})
-            return {"statusCode": 500, "body": "Internal Server Error"}
-        else:
-            config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))
+
+    response = Response(config)
+    final_response = response.prepare_response()
+    if "error" in final_response:
+        logging.error(f'Error: {final_response["error"]}')
+        config.socket.send({"type": "error", "message": "Internal Server Error"})
+        return {"statusCode": 500, "body": "Internal Server Error"}
+    else:
+        config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))
 
     log_group = os.getenv('METRICS_LOG_GROUP')
     log_stream = context.log_stream_name
diff --git a/chat/src/handlers/chat_sync.py b/chat/src/handlers/chat_sync.py
index 8166870e..dd3d88cf 100644
--- a/chat/src/handlers/chat_sync.py
+++ b/chat/src/handlers/chat_sync.py
@@ -17,8 +17,6 @@
 }
 
 def handler(event, context):
-    print(f'Event: {event}')
-
     config = HTTPEventConfig(event)
 
     if not config.is_logged_in:
@@ -27,7 +25,7 @@ def handler(event, context):
     if config.question is None or config.question == "":
         return {"statusCode": 400, "body": "Question cannot be blank"}
 
-    if not os.getenv("SKIP_WEAVIATE_SETUP"):
+    if not os.getenv("SKIP_LLM_REQUEST"):
         config.setup_llm_request()
     response = HTTPResponse(config)
     final_response = response.prepare_response()
diff --git a/chat/src/helpers/apitoken.py b/chat/src/helpers/apitoken.py
index 46c97263..0cd03e29 100644
--- a/chat/src/helpers/apitoken.py
+++ b/chat/src/helpers/apitoken.py
@@ -13,6 +13,7 @@ def empty_token(cls):
             "iat": time,
             "entitlements": [],
             "isLoggedIn": False,
+            "isDevTeam": False,
         }
 
     def __init__(self, signed_token=None):
@@ -33,3 +34,6 @@ def is_logged_in(self):
 
     def is_superuser(self):
         return self.token.get("isSuperUser", False)
+
+    def is_dev_team(self):
+        return self.token.get("isDevTeam", False)
diff --git a/chat/src/helpers/metrics.py b/chat/src/helpers/metrics.py
index 9610eac7..f00abc00 100644
--- a/chat/src/helpers/metrics.py
+++ b/chat/src/helpers/metrics.py
@@ -8,6 +8,7 @@ def debug_response(config, response, original_question):
         "attributes": config.attributes,
         "azure_endpoint": config.azure_endpoint,
         "deployment_name": config.deployment_name,
+        "is_dev_team": config.api_token.is_dev_team(),
         "is_superuser": config.api_token.is_superuser(),
         "k": config.k,
         "openai_api_version": config.openai_api_version,
diff --git a/chat/test/fixtures/apitoken.py b/chat/test/fixtures/apitoken.py
index 08691856..5c36541d 100644
--- a/chat/test/fixtures/apitoken.py
+++ b/chat/test/fixtures/apitoken.py
@@ -6,3 +6,4 @@
 TEST_TOKEN = ('eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4NDM1ODY2MDYxNjUs'
               'ImlhdCI6MTY4Nzg5MTM2OSwiZW50aXRsZW1lbnRzIjpbXSwiaXNMb2dnZWRJbiI6d'
               'HJ1ZSwic3ViIjoidGVzdFVzZXIifQ.vIZag1pHE1YyrxsKKlakXX_44ckAvkg7xWOoA_w4x58')
+DEV_TEAM_TOKEN = ('eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI1MjY1OTQ2MDcsInN1YiI6ImFiYzEyMyIsImlzcyI6Im1lYWRvdyIsImlhdCI6MTcyNDk1OTUyNiwiZW50aXRsZW1lbnRzIjpbXSwiaXNMb2dnZWRJbiI6dHJ1ZSwiaXNTdXBlclVzZXIiOmZhbHNlLCJpc0RldlRlYW0iOnRydWV9.YrgNRcksnf1e0lIUdo3gdyAZR0_vUsuGzY9h6gziZbY')
diff --git a/chat/test/handlers/test_chat.py b/chat/test/handlers/test_chat.py
index 532c1bc4..a2cd93e8 100644
--- a/chat/test/handlers/test_chat.py
+++ b/chat/test/handlers/test_chat.py
@@ -10,6 +10,7 @@
 from unittest.mock import patch
 from handlers.chat import handler
 from helpers.apitoken import ApiToken
+from helpers.response import Response
 from websocket import Websocket
 from event_config import EventConfig
 
@@ -25,12 +26,36 @@ class MockContext:
     def __init__(self):
         self.log_stream_name = 'test'
 
+# TODO: Find a way to build a better mock response (maybe using helpers.metrics.debug_response)
+def mock_response(**kwargs):
+    result = {
+        'answer': 'Answer.',
+        'attributes': ['accession_number', 'alternate_title', 'api_link', 'canonical_link', 'caption', 'collection', 'contributor', 'date_created', 'date_created_edtf', 'description', 'genre', 'id', 'identifier', 'keywords', 'language', 'notes', 'physical_description_material', 'physical_description_size', 'provenance', 'publisher', 'rights_statement', 'subject', 'table_of_contents', 'thumbnail', 'title', 'visibility', 'work_type'],
+        'azure_endpoint': 'https://nul-ai-east.openai.azure.com/',
+        'deployment_name': 'gpt-4o',
+        'is_dev_team': False,
+        'is_superuser': False,
+        'k': 10,
+        'openai_api_version': '2024-02-01',
+        'prompt': "Prompt",
+        'question': 'Question?',
+        'ref': 'ref123',
+        'size': 20,
+        'source_documents': [],
+        'temperature': 0.2,
+        'text_key': 'id',
+        'token_counts': {'question': 19, 'answer': 348, 'prompt': 329, 'source_documents': 10428, 'total': 11124}
+    }
+    result.update(kwargs)
+    return result
+
 @mock.patch.dict(
     os.environ,
     {
         "AZURE_OPENAI_RESOURCE_NAME": "test",
     },
 )
+@mock.patch.object(Response, "prepare_response", lambda _: mock_response())
 class TestHandler(TestCase):
     def test_handler_unauthorized(self):
         event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
@@ -45,7 +70,7 @@ def test_handler_success(self, mock_is_logged_in):
     @patch.object(ApiToken, 'is_logged_in')
     @patch.object(ApiToken, 'is_superuser')
     @patch.object(EventConfig, '_is_debug_mode_enabled')
-    def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_logged_in, mock_is_superuser):
+    def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_superuser, mock_is_logged_in):
         mock_is_debug_enabled.return_value = True
         mock_is_logged_in.return_value = True
         mock_is_superuser.return_value = True
@@ -54,21 +79,23 @@ def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_logged_in, mock
         event = {"socket": mock_websocket, "debug": True, "body": '{"question": "Question?"}' }
         handler(event, MockContext())
         response = json.loads(mock_client.received_data)
-        self.assertEqual(response["type"], "debug")
+        expected_keys = {"attributes", "azure_endpoint", "deployment_name"}
+        received_keys = response.keys()
+        self.assertTrue(expected_keys.issubset(received_keys))
 
     @patch.object(ApiToken, 'is_logged_in')
     @patch.object(ApiToken, 'is_superuser')
-    @patch.object(EventConfig, '_is_debug_mode_enabled')
-    def test_handler_debug_mode_for_superusers_only(self, mock_is_debug_enabled, mock_is_logged_in, mock_is_superuser):
-        mock_is_debug_enabled.return_value = True
+    def test_handler_debug_mode_for_superusers_only(self, mock_is_superuser, mock_is_logged_in):
         mock_is_logged_in.return_value = True
         mock_is_superuser.return_value = False
         mock_client = MockClient()
         mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
-        event = {"socket": mock_websocket, "debug": True, "body": '{"question": "Question?"}' }
+        event = {"socket": mock_websocket, "body": '{"question": "Question?", "debug": "true"}'}
         handler(event, MockContext())
         response = json.loads(mock_client.received_data)
-        self.assertEqual(response["type"], "error")
+        expected_keys = {"answer", "ref"}
+        received_keys = set(response.keys())
+        self.assertSetEqual(received_keys, expected_keys)
 
     @patch.object(ApiToken, 'is_logged_in')
     def test_handler_question_missing(self, mock_is_logged_in):
diff --git a/chat/test/helpers/test_apitoken.py b/chat/test/helpers/test_apitoken.py
index a330f56a..e23a1646 100644
--- a/chat/test/helpers/test_apitoken.py
+++ b/chat/test/helpers/test_apitoken.py
@@ -1,15 +1,16 @@
 # ruff: noqa: E402
 import os
 import sys
+
 sys.path.append('./src')
 
 from helpers.apitoken import ApiToken
-from test.fixtures.apitoken import SUPER_TOKEN, TEST_SECRET, TEST_TOKEN
+from test.fixtures.apitoken import DEV_TEAM_TOKEN, SUPER_TOKEN, TEST_SECRET, TEST_TOKEN
 from unittest import mock, TestCase
 
-
+@mock.patch.dict(os.environ, {"DEV_TEAM_NET_IDS": "abc123"})
 @mock.patch.dict(os.environ, {"API_TOKEN_SECRET": TEST_SECRET})
 class TestFunction(TestCase):
     def test_empty_token(self):
@@ -29,6 +30,11 @@ def test_superuser_token(self):
         self.assertTrue(subject.is_logged_in())
         self.assertTrue(subject.is_superuser())
 
+    def test_devteam_token(self):
+        subject = ApiToken(DEV_TEAM_TOKEN)
+        self.assertIsInstance(subject, ApiToken)
+        self.assertTrue(subject.is_dev_team())
+
     def test_invalid_token(self):
         subject = ApiToken("INVALID_TOKEN")
         self.assertIsInstance(subject, ApiToken)
diff --git a/node/src/api/api-token.js b/node/src/api/api-token.js
index 51a6f490..4ea75c38 100644
--- a/node/src/api/api-token.js
+++ b/node/src/api/api-token.js
@@ -1,4 +1,8 @@
-const { apiTokenSecret, dcApiEndpoint } = require("../environment");
+const {
+  apiTokenSecret,
+  dcApiEndpoint,
+  devTeamNetIds,
+} = require("../environment");
 const jwt = require("jsonwebtoken");
 
 function emptyToken() {
@@ -35,8 +39,8 @@ class ApiToken {
       email: user?.mail,
       isLoggedIn: !!user,
       primaryAffiliation: user?.primaryAffiliation,
+      isDevTeam: !!user && user?.uid && devTeamNetIds().includes(user?.uid),
     };
-
     return this.update();
   }
 
@@ -102,6 +106,10 @@ class ApiToken {
     return this.token.entitlements.has(entitlement);
   }
 
+  isDevTeam() {
+    return this.token.isDevTeam;
+  }
+
   isLoggedIn() {
     return this.token.isLoggedIn;
   }
diff --git a/node/src/environment.js b/node/src/environment.js
index 44958569..cf330b1c 100644
--- a/node/src/environment.js
+++ b/node/src/environment.js
@@ -40,6 +40,10 @@ function dcUrl() {
   return process.env.DC_URL;
 }
 
+function devTeamNetIds() {
+  return process.env.DEV_TEAM_NET_IDS.split(",");
+}
+
 function openSearchEndpoint() {
   return process.env.OPENSEARCH_ENDPOINT;
 }
@@ -61,6 +65,7 @@ module.exports = {
   appInfo,
   dcApiEndpoint,
   dcUrl,
+  devTeamNetIds,
   openSearchEndpoint,
   prefix,
   region,
diff --git a/node/test/test-helpers/index.js b/node/test/test-helpers/index.js
index 7ae31a2f..da85f1b2 100644
--- a/node/test/test-helpers/index.js
+++ b/node/test/test-helpers/index.js
@@ -11,6 +11,7 @@ const TestEnvironment = {
   API_TOKEN_NAME: "dcapiTEST",
   DC_URL: "https://thisisafakedcurl",
   DC_API_ENDPOINT: "https://thisisafakeapiurl",
+  DEV_TEAM_NET_IDS: "abc123,def456",
   NUSSO_BASE_URL: "https://nusso-base.com/",
   NUSSO_API_KEY: "abc123",
   WEBSOCKET_URI: "wss://thisisafakewebsocketapiurl",
diff --git a/node/test/unit/api/api-token.test.js b/node/test/unit/api/api-token.test.js
index e8cd2d6b..b18fcca2 100644
--- a/node/test/unit/api/api-token.test.js
+++ b/node/test/unit/api/api-token.test.js
@@ -74,6 +74,21 @@ describe("ApiToken", function () {
     });
   });
 
+  describe("isDevTeam", function () {
+    it("sets the isDevTeam flag to true", async () => {
+      const user = {
+        uid: "abc123",
+        displayName: ["A. Developer"],
+        mail: "user@example.com",
+      };
+      const token = new ApiToken();
+      token.user(user);
+
+      expect(token.isDevTeam()).to.be.true;
+      expect(token.isLoggedIn()).to.be.true;
+    });
+  });
+
   describe("entitlements", function () {
     it("addEntitlement() adds an entitlement to the token", async () => {
       const payload = {
diff --git a/template.yaml b/template.yaml
index 67fd7ad6..9801c329 100644
--- a/template.yaml
+++ b/template.yaml
@@ -26,6 +26,7 @@ Globals:
         API_TOKEN_SECRET: !Ref ApiTokenSecret
         DC_API_ENDPOINT: !Ref DcApiEndpoint
         DC_URL: !Ref DcUrl
+        DEV_TEAM_NET_IDS: !Ref DevTeamNetIds
         OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
         ENV_PREFIX: !Ref EnvironmentPrefix
         HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
@@ -63,6 +64,9 @@ Parameters:
   DcUrl:
     Type: String
     Description: URL of Digital Collections website
+  DevTeamNetIds:
+    Type: String
+    Description: Northwestern NetIDs of the development team
   FfmpegLayer:
     Type: String
     Description: "FFMPEG Lambda Layer ARN"