From d76afc1aabdf436238b2fb1d2868e626e8bb7df9 Mon Sep 17 00:00:00 2001 From: Chris Helma Date: Tue, 10 Dec 2024 12:55:03 -0600 Subject: [PATCH] TP: Minor refactoring, added more unit tests Signed-off-by: Chris Helma --- .../transform_api/tests/test_views.py | 157 ++++++++++++++++++ .../tp_backend/transform_api/views.py | 14 +- .../tp_backend/transform_expert/expert.py | 54 +++++- .../tests/utils/test_opensearch_client.py | 2 +- .../transform_expert/utils/inference.py | 45 ----- .../transform_expert/utils/transforms.py | 8 - 6 files changed, 213 insertions(+), 67 deletions(-) create mode 100644 TransformationPlayground/tp_backend/transform_api/tests/test_views.py delete mode 100644 TransformationPlayground/tp_backend/transform_expert/utils/inference.py diff --git a/TransformationPlayground/tp_backend/transform_api/tests/test_views.py b/TransformationPlayground/tp_backend/transform_api/tests/test_views.py new file mode 100644 index 000000000..482abc723 --- /dev/null +++ b/TransformationPlayground/tp_backend/transform_api/tests/test_views.py @@ -0,0 +1,157 @@ +from unittest.mock import patch, MagicMock +from django.test import TestCase +from rest_framework.test import APIClient +from rest_framework import status + +from transform_expert.validation import TestTargetInnaccessibleError + + +class TransformsIndexViewTestCase(TestCase): + def setUp(self): + self.client = APIClient() + self.url = "/transforms/index/" + + self.valid_request_body = { + "transform_language": "Python", + "source_version": "Elasticsearch 6.8", + "target_version": "OpenSearch 2.17", + "input_shape": { + "index_name": "test-index", + "index_json": { + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 0 + } + }, + "mappings": { + "type1": { + "properties": { + "title": {"type": "text"} + } + }, + "type2": { + "properties": { + "contents": {"type": "text"} + } + } + } + } + }, + "test_target_url": "http://localhost:29200" + } + + self.valid_response_body = { + "output_shape": [ + { + "index_name": "test-index-type1", + "index_json": { + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 0 + } + }, + "mappings": { + "properties": { + "title": {"type": "text"} + } + } + } + }, + { + "index_name": "test-index-type2", + "index_json": { + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 0 + } + }, + "mappings": { + "properties": { + "contents": {"type": "text"} + } + } + } + } + ], + "transform_logic": "Generated Python transformation logic" + } + + @patch("transform_api.views.TransformsIndexView._perform_transformation") + def test_post_happy_path(self, mock_perform_transformation): + # Mock the transformation result + mock_transform_report = MagicMock() + mock_transform_report.task.output = self.valid_response_body["output_shape"] + mock_transform_report.task.transform.to_file_format.return_value = self.valid_response_body["transform_logic"] + mock_perform_transformation.return_value = mock_transform_report + + # Make the request + response = self.client.post(self.url, self.valid_request_body, format="json") + + # Assertions + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json(), self.valid_response_body) + mock_perform_transformation.assert_called_once() + + def test_post_invalid_request_body(self): + # Incomplete request body + invalid_request_body = {"transform_language": "Python"} + + # Make the request + response = self.client.post(self.url, invalid_request_body, format="json") + + # Assertions + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn("input_shape", response.json()) + + @patch("transform_api.views.TransformsIndexView._perform_transformation") + def test_post_inaccessible_target_cluster(self, mock_perform_transformation): + # Mock the `_perform_transformation` method to raise `TestTargetInnaccessibleError` + mock_perform_transformation.side_effect = TestTargetInnaccessibleError("Cluster not accessible") + + # Make the request + response = self.client.post(self.url, self.valid_request_body, format="json") + + # Assertions + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {"error": "Cluster not accessible"}) + mock_perform_transformation.assert_called_once() + + @patch("transform_api.views.TransformsIndexView._perform_transformation") + def test_post_general_transformation_failure(self, mock_perform_transformation): + # Mock the `_perform_transformation` method to raise a general exception + mock_perform_transformation.side_effect = RuntimeError("General failure") + + # Make the request + response = self.client.post(self.url, self.valid_request_body, format="json") + + # Assertions + self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + self.assertEqual(response.json(), {"error": "General failure"}) + mock_perform_transformation.assert_called_once() + + @patch("transform_api.views.TransformsIndexCreateResponseSerializer") + @patch("transform_api.views.TransformsIndexView._perform_transformation") + def test_post_invalid_response(self, mock_perform_transformation, mock_response_serializer): + # Mock the transformation result + mock_transform_report = MagicMock() + mock_transform_report.task.output = self.valid_response_body["output_shape"] + mock_transform_report.task.transform.to_file_format.return_value = self.valid_response_body["transform_logic"] + mock_perform_transformation.return_value = mock_transform_report + + # Mock the serializer behavior + mock_serializer_instance = mock_response_serializer.return_value + mock_serializer_instance.is_valid.return_value = False + mock_serializer_instance.errors = {"transform_logic": ["Invalid format"]} + + # Make the request + response = self.client.post(self.url, self.valid_request_body, format="json") + + # Assertions + self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + mock_perform_transformation.assert_called_once() + mock_serializer_instance.is_valid.assert_called_once() + self.assertEqual(mock_serializer_instance.errors, {"transform_logic": ["Invalid format"]}) + diff --git a/TransformationPlayground/tp_backend/transform_api/views.py b/TransformationPlayground/tp_backend/transform_api/views.py index 91cc18101..51d5f3c24 100644 --- a/TransformationPlayground/tp_backend/transform_api/views.py +++ b/TransformationPlayground/tp_backend/transform_api/views.py @@ -1,15 +1,12 @@ import logging +import uuid +from langchain_core.messages import HumanMessage from rest_framework.views import APIView from rest_framework.response import Response from rest_framework import status -from .serializers import TransformsIndexCreateRequestSerializer, TransformsIndexCreateResponseSerializer - - -import uuid - -from langchain_core.messages import HumanMessage +from .serializers import TransformsIndexCreateRequestSerializer, TransformsIndexCreateResponseSerializer from transform_expert.expert import get_expert, invoke_expert from transform_expert.parameters import TransformType from transform_expert.validation import test_target_connection, TestTargetInnaccessibleError, IndexTransformValidator, ValidationReport @@ -20,6 +17,7 @@ logger = logging.getLogger("transform_api") + class TransformsIndexView(APIView): def post(self, request): logger.info(f"Received transformation request: {request.data}") @@ -94,7 +92,7 @@ def _perform_transformation(self, transform_id: str, request: TransformsIndexCre transform_result = invoke_expert(expert, transform_task) - # Execute the transformation on the input - transform_test_report = IndexTransformValidator(transform_task, test_connection).validate() + # Execute the transformation on the input and test it against the target cluster + transform_test_report = IndexTransformValidator(transform_result, test_connection).validate() return transform_test_report diff --git a/TransformationPlayground/tp_backend/transform_expert/expert.py b/TransformationPlayground/tp_backend/transform_expert/expert.py index 523fdacd4..44f9e7034 100644 --- a/TransformationPlayground/tp_backend/transform_expert/expert.py +++ b/TransformationPlayground/tp_backend/transform_expert/expert.py @@ -1,6 +1,7 @@ +import asyncio from dataclasses import dataclass import logging -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List from botocore.config import Config from langchain_aws import ChatBedrockConverse @@ -8,8 +9,6 @@ from langchain_core.messages import BaseMessage, SystemMessage, ToolMessage from langchain_core.runnables import Runnable - -from transform_expert.utils.inference import perform_inference from transform_expert.parameters import SourceVersion, TargetVersion, TransformType, TransformLanguage from transform_expert.prompting import get_system_prompt_factory from transform_expert.tools import ToolBundle, get_tool_bundle @@ -63,11 +62,11 @@ def get_expert(source_version: SourceVersion, target_version: TargetVersion, tra def invoke_expert(expert: Expert, task: TransformTask) -> TransformTask: logger.info(f"Invoking the Transform Expert for transform_id: {task.transform_id}") - logger.debug(f"Transform Task: {str(task.to_json())}") + logger.debug(f"Initial Transform Task: {str(task.to_json())}") # Invoke the LLM. This should result in the LLM making a tool call, forcing it to create the transform details by # conforming to the tool's schema. - inference_task = task.to_inference_task() + inference_task = InferenceTask.from_transform_task(task) inference_result = perform_inference(expert.llm, [inference_task])[0] logger.debug(f"Inference Result: {str(inference_result.to_json())}") @@ -93,3 +92,48 @@ def invoke_expert(expert: Expert, task: TransformTask) -> TransformTask: logger.debug(f"Updated Transform Task: {str(task.to_json())}") return task + + +@dataclass +class InferenceTask: + transform_id: str + context: List[BaseMessage] + + @staticmethod + def from_transform_task(task: TransformTask) -> 'InferenceTask': + return InferenceTask( + transform_id=task.transform_id, + context=task.context + ) + + def to_json(self) -> dict: + return { + "transform_id": self.transform_id, + "context": [turn.to_json() for turn in self.context] + } + +@dataclass +class InferenceResult: + transform_id: str + response: BaseMessage + + def to_json(self) -> dict: + return { + "transform_id": self.transform_id, + "response": self.response.to_json() + } + + +def perform_inference(llm: Runnable[LanguageModelInput, BaseMessage], batched_tasks: List[InferenceTask]) -> List[InferenceResult]: + return asyncio.run(_perform_async_inference(llm, batched_tasks)) + +# Inference APIs can be throttled pretty aggressively. Performing them as a batch operation can help with increasing +# throughput. Ideally, we'd be using Bedrock's batch inference API, but Bedrock's approach to that is an asynchronous +# process that writes the results to S3 and returns a URL to the results. This is not implemented by default in the +# ChatBedrockConverse class, so we'll skip true batch processing for now. Instead, we'll just perform the inferences in +# parallel with aggressive retry logic. +async def _perform_async_inference(llm: Runnable[LanguageModelInput, BaseMessage], batched_tasks: List[InferenceTask]) -> List[InferenceResult]: + async_responses = [llm.ainvoke(task.context) for task in batched_tasks] + responses = await asyncio.gather(*async_responses) + + return [InferenceResult(transform_id=task.transform_id, response=response) for task, response in zip(batched_tasks, responses)] diff --git a/TransformationPlayground/tp_backend/transform_expert/tests/utils/test_opensearch_client.py b/TransformationPlayground/tp_backend/transform_expert/tests/utils/test_opensearch_client.py index d5ad7f406..98c6914b8 100644 --- a/TransformationPlayground/tp_backend/transform_expert/tests/utils/test_opensearch_client.py +++ b/TransformationPlayground/tp_backend/transform_expert/tests/utils/test_opensearch_client.py @@ -1,5 +1,5 @@ from django.test import TestCase -from unittest.mock import patch, MagicMock +from unittest.mock import patch from requests import HTTPError, ConnectionError from transform_expert.utils.opensearch_client import OpenSearchClient from transform_expert.utils.rest_client import RESTClient, ConnectionDetails diff --git a/TransformationPlayground/tp_backend/transform_expert/utils/inference.py b/TransformationPlayground/tp_backend/transform_expert/utils/inference.py deleted file mode 100644 index 2f3078b58..000000000 --- a/TransformationPlayground/tp_backend/transform_expert/utils/inference.py +++ /dev/null @@ -1,45 +0,0 @@ -import asyncio -from dataclasses import dataclass -from typing import List - -from langchain_core.language_models import LanguageModelInput -from langchain_core.messages import BaseMessage -from langchain_core.runnables import Runnable - -@dataclass -class InferenceTask: - transform_id: str - context: List[BaseMessage] - - def to_json(self) -> dict: - return { - "transform_id": self.transform_id, - "context": [turn.to_json() for turn in self.context] - } - -@dataclass -class InferenceResult: - transform_id: str - response: BaseMessage - - def to_json(self) -> dict: - return { - "transform_id": self.transform_id, - "response": self.response.to_json() - } - - -def perform_inference(llm: Runnable[LanguageModelInput, BaseMessage], batched_tasks: List[InferenceTask]) -> List[InferenceResult]: - return asyncio.run(_perform_async_inference(llm, batched_tasks)) - -# Inference APIs can be throttled pretty aggressively. Performing them as a batch operation can help with increasing -# throughput. Ideally, we'd be using Bedrock's batch inference API, but Bedrock's approach to that is an asynchronous -# process that writes the results to S3 and returns a URL to the results. This is not implemented by default in the -# ChatBedrockConverse class, so we'll skip true batch processing for now. Instead, we'll just perform the inferences in -# parallel with aggressive retry logic. -async def _perform_async_inference(llm: Runnable[LanguageModelInput, BaseMessage], batched_tasks: List[InferenceTask]) -> List[InferenceResult]: - async_responses = [llm.ainvoke(task.context) for task in batched_tasks] - responses = await asyncio.gather(*async_responses) - - return [InferenceResult(transform_id=task.transform_id, response=response) for task, response in zip(batched_tasks, responses)] - \ No newline at end of file diff --git a/TransformationPlayground/tp_backend/transform_expert/utils/transforms.py b/TransformationPlayground/tp_backend/transform_expert/utils/transforms.py index 1b9dc7276..f699d8279 100644 --- a/TransformationPlayground/tp_backend/transform_expert/utils/transforms.py +++ b/TransformationPlayground/tp_backend/transform_expert/utils/transforms.py @@ -6,8 +6,6 @@ from langchain_core.messages import BaseMessage -from transform_expert.utils.inference import InferenceTask - logger = logging.getLogger("transform_expert") @@ -53,12 +51,6 @@ def to_json(self) -> Dict[str, Any]: "transform": self.transform.to_json() if self.transform else None, "output": self.output if self.output else None } - - def to_inference_task(self) -> InferenceTask: - return InferenceTask( - transform_id=self.transform_id, - context=self.context - ) class TransformInvalidSyntaxError(Exception): pass