Skip to content

Commit

Permalink
TP: Minor refactoring, added more unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Chris Helma <chelma+github@amazon.com>
  • Loading branch information
chelma committed Dec 10, 2024
1 parent fc36eb2 commit d76afc1
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 67 deletions.
157 changes: 157 additions & 0 deletions TransformationPlayground/tp_backend/transform_api/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from unittest.mock import patch, MagicMock
from django.test import TestCase
from rest_framework.test import APIClient
from rest_framework import status

from transform_expert.validation import TestTargetInnaccessibleError


class TransformsIndexViewTestCase(TestCase):
    """Tests for the POST /transforms/index/ endpoint (TransformsIndexView).

    The expensive transformation step is patched out via
    ``TransformsIndexView._perform_transformation`` so each test exercises only
    the view's request validation, dispatch, and error-to-status mapping.
    """

    def setUp(self):
        """Create an API client plus a canonical valid request/response pair."""
        self.client = APIClient()
        self.url = "/transforms/index/"

        # A well-formed request: a multi-type ES 6.8 index to be transformed
        # for an OpenSearch 2.17 target, validated against a local test cluster.
        self.valid_request_body = {
            "transform_language": "Python",
            "source_version": "Elasticsearch 6.8",
            "target_version": "OpenSearch 2.17",
            "input_shape": {
                "index_name": "test-index",
                "index_json": {
                    "settings": {
                        "index": {
                            "number_of_shards": 1,
                            "number_of_replicas": 0
                        }
                    },
                    "mappings": {
                        "type1": {
                            "properties": {
                                "title": {"type": "text"}
                            }
                        },
                        "type2": {
                            "properties": {
                                "contents": {"type": "text"}
                            }
                        }
                    }
                }
            },
            "test_target_url": "http://localhost:29200"
        }

        # Expected happy-path response: one output index per original mapping type.
        self.valid_response_body = {
            "output_shape": [
                {
                    "index_name": "test-index-type1",
                    "index_json": {
                        "settings": {
                            "index": {
                                "number_of_shards": 1,
                                "number_of_replicas": 0
                            }
                        },
                        "mappings": {
                            "properties": {
                                "title": {"type": "text"}
                            }
                        }
                    }
                },
                {
                    "index_name": "test-index-type2",
                    "index_json": {
                        "settings": {
                            "index": {
                                "number_of_shards": 1,
                                "number_of_replicas": 0
                            }
                        },
                        "mappings": {
                            "properties": {
                                "contents": {"type": "text"}
                            }
                        }
                    }
                }
            ],
            "transform_logic": "Generated Python transformation logic"
        }

    @patch("transform_api.views.TransformsIndexView._perform_transformation")
    def test_post_happy_path(self, mock_perform_transformation):
        """A valid request returns 200 with the serialized transformation result."""
        # Mock the transformation result
        mock_transform_report = MagicMock()
        mock_transform_report.task.output = self.valid_response_body["output_shape"]
        mock_transform_report.task.transform.to_file_format.return_value = self.valid_response_body["transform_logic"]
        mock_perform_transformation.return_value = mock_transform_report

        # Make the request
        response = self.client.post(self.url, self.valid_request_body, format="json")

        # Assertions
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.json(), self.valid_response_body)
        mock_perform_transformation.assert_called_once()

    def test_post_invalid_request_body(self):
        """A request missing required fields is rejected with 400 and field errors."""
        # Incomplete request body
        invalid_request_body = {"transform_language": "Python"}

        # Make the request
        response = self.client.post(self.url, invalid_request_body, format="json")

        # Assertions: the missing field should be named in the error payload
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("input_shape", response.json())

    @patch("transform_api.views.TransformsIndexView._perform_transformation")
    def test_post_inaccessible_target_cluster(self, mock_perform_transformation):
        """An unreachable test target cluster maps to a 400 with the error message."""
        # Mock the `_perform_transformation` method to raise `TestTargetInnaccessibleError`
        mock_perform_transformation.side_effect = TestTargetInnaccessibleError("Cluster not accessible")

        # Make the request
        response = self.client.post(self.url, self.valid_request_body, format="json")

        # Assertions
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), {"error": "Cluster not accessible"})
        mock_perform_transformation.assert_called_once()

    @patch("transform_api.views.TransformsIndexView._perform_transformation")
    def test_post_general_transformation_failure(self, mock_perform_transformation):
        """Any unexpected transformation exception maps to a 500 with the message."""
        # Mock the `_perform_transformation` method to raise a general exception
        mock_perform_transformation.side_effect = RuntimeError("General failure")

        # Make the request
        response = self.client.post(self.url, self.valid_request_body, format="json")

        # Assertions
        self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
        self.assertEqual(response.json(), {"error": "General failure"})
        mock_perform_transformation.assert_called_once()

    @patch("transform_api.views.TransformsIndexCreateResponseSerializer")
    @patch("transform_api.views.TransformsIndexView._perform_transformation")
    def test_post_invalid_response(self, mock_perform_transformation, mock_response_serializer):
        """A response payload failing serializer validation maps to a 500."""
        # Mock the transformation result
        mock_transform_report = MagicMock()
        mock_transform_report.task.output = self.valid_response_body["output_shape"]
        mock_transform_report.task.transform.to_file_format.return_value = self.valid_response_body["transform_logic"]
        mock_perform_transformation.return_value = mock_transform_report

        # Mock the serializer behavior so its validation fails
        mock_serializer_instance = mock_response_serializer.return_value
        mock_serializer_instance.is_valid.return_value = False
        mock_serializer_instance.errors = {"transform_logic": ["Invalid format"]}

        # Make the request
        response = self.client.post(self.url, self.valid_request_body, format="json")

        # Assertions
        self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
        mock_perform_transformation.assert_called_once()
        mock_serializer_instance.is_valid.assert_called_once()
        # The view must have built its response through the (patched) response
        # serializer rather than returning the raw transformation output.
        # (The previous assertion here compared the mock's `errors` attribute to
        # the value the test itself assigned, which was a tautology.)
        mock_response_serializer.assert_called_once()

14 changes: 6 additions & 8 deletions TransformationPlayground/tp_backend/transform_api/views.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import logging
import uuid

from langchain_core.messages import HumanMessage
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from .serializers import TransformsIndexCreateRequestSerializer, TransformsIndexCreateResponseSerializer


import uuid

from langchain_core.messages import HumanMessage

from .serializers import TransformsIndexCreateRequestSerializer, TransformsIndexCreateResponseSerializer
from transform_expert.expert import get_expert, invoke_expert
from transform_expert.parameters import TransformType
from transform_expert.validation import test_target_connection, TestTargetInnaccessibleError, IndexTransformValidator, ValidationReport
Expand All @@ -20,6 +17,7 @@

logger = logging.getLogger("transform_api")


class TransformsIndexView(APIView):
def post(self, request):
logger.info(f"Received transformation request: {request.data}")
Expand Down Expand Up @@ -94,7 +92,7 @@ def _perform_transformation(self, transform_id: str, request: TransformsIndexCre

transform_result = invoke_expert(expert, transform_task)

# Execute the transformation on the input
transform_test_report = IndexTransformValidator(transform_task, test_connection).validate()
# Execute the transformation on the input and test it against the target cluster
transform_test_report = IndexTransformValidator(transform_result, test_connection).validate()

return transform_test_report
54 changes: 49 additions & 5 deletions TransformationPlayground/tp_backend/transform_expert/expert.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import asyncio
from dataclasses import dataclass
import logging
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List

from botocore.config import Config
from langchain_aws import ChatBedrockConverse
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable


from transform_expert.utils.inference import perform_inference
from transform_expert.parameters import SourceVersion, TargetVersion, TransformType, TransformLanguage
from transform_expert.prompting import get_system_prompt_factory
from transform_expert.tools import ToolBundle, get_tool_bundle
Expand Down Expand Up @@ -63,11 +62,11 @@ def get_expert(source_version: SourceVersion, target_version: TargetVersion, tra

def invoke_expert(expert: Expert, task: TransformTask) -> TransformTask:
logger.info(f"Invoking the Transform Expert for transform_id: {task.transform_id}")
logger.debug(f"Transform Task: {str(task.to_json())}")
logger.debug(f"Initial Transform Task: {str(task.to_json())}")

# Invoke the LLM. This should result in the LLM making a tool call, forcing it to create the transform details by
# conforming to the tool's schema.
inference_task = task.to_inference_task()
inference_task = InferenceTask.from_transform_task(task)
inference_result = perform_inference(expert.llm, [inference_task])[0]

logger.debug(f"Inference Result: {str(inference_result.to_json())}")
Expand All @@ -93,3 +92,48 @@ def invoke_expert(expert: Expert, task: TransformTask) -> TransformTask:
logger.debug(f"Updated Transform Task: {str(task.to_json())}")

return task


@dataclass
class InferenceTask:
    """A single LLM inference request.

    Attributes:
        transform_id: Correlation id tying this inference back to its TransformTask.
        context: The conversation turns to send to the LLM.
    """
    transform_id: str
    context: List[BaseMessage]

    @classmethod
    def from_transform_task(cls, task: TransformTask) -> "InferenceTask":
        """Alternate constructor: lift the id/context pair out of a TransformTask.

        (Declared as a classmethod — the idiomatic form for alternate
        constructors — rather than a staticmethod; call sites are unchanged.)
        """
        return cls(
            transform_id=task.transform_id,
            context=task.context
        )

    def to_json(self) -> dict:
        """Return a JSON-serializable dict representation of this task."""
        return {
            "transform_id": self.transform_id,
            "context": [turn.to_json() for turn in self.context]
        }

@dataclass
class InferenceResult:
    """The outcome of one LLM inference call.

    Attributes:
        transform_id: Correlation id of the InferenceTask that produced this result.
        response: The message returned by the LLM.
    """
    transform_id: str
    response: BaseMessage

    def to_json(self) -> dict:
        """Return a JSON-serializable dict representation of this result."""
        serialized_response = self.response.to_json()
        return dict(transform_id=self.transform_id, response=serialized_response)


def perform_inference(llm: Runnable[LanguageModelInput, BaseMessage], batched_tasks: List[InferenceTask]) -> List[InferenceResult]:
    """Synchronously run a batch of inference tasks against the given LLM.

    Thin blocking wrapper: drives the async implementation to completion on a
    fresh event loop and returns one InferenceResult per task, in task order.
    """
    results = asyncio.run(_perform_async_inference(llm, batched_tasks))
    return results

# Bedrock's batch inference API is asynchronous — it writes results to S3 and
# hands back a URL — and ChatBedrockConverse does not implement it by default,
# so true batch processing is skipped for now. Running the individual calls
# concurrently still improves throughput against aggressively throttled
# inference APIs, with aggressive retry logic covering transient failures.
async def _perform_async_inference(llm: Runnable[LanguageModelInput, BaseMessage], batched_tasks: List[InferenceTask]) -> List[InferenceResult]:
    """Fan out one LLM call per task concurrently and pair up the responses."""
    responses = await asyncio.gather(
        *(llm.ainvoke(task.context) for task in batched_tasks)
    )
    # asyncio.gather preserves input order, so zipping re-associates each
    # response with the task that produced it.
    return [
        InferenceResult(transform_id=task.transform_id, response=response)
        for task, response in zip(batched_tasks, responses)
    ]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from django.test import TestCase
from unittest.mock import patch, MagicMock
from unittest.mock import patch
from requests import HTTPError, ConnectionError
from transform_expert.utils.opensearch_client import OpenSearchClient
from transform_expert.utils.rest_client import RESTClient, ConnectionDetails
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from langchain_core.messages import BaseMessage

from transform_expert.utils.inference import InferenceTask


logger = logging.getLogger("transform_expert")

Expand Down Expand Up @@ -53,12 +51,6 @@ def to_json(self) -> Dict[str, Any]:
"transform": self.transform.to_json() if self.transform else None,
"output": self.output if self.output else None
}

def to_inference_task(self) -> InferenceTask:
return InferenceTask(
transform_id=self.transform_id,
context=self.context
)

class TransformInvalidSyntaxError(Exception):
    """Error type for a transform whose code has invalid syntax.

    NOTE(review): the raising sites are not visible in this diff — presumably
    raised when generated transform logic fails to parse/compile; confirm.
    """
    pass
Expand Down

0 comments on commit d76afc1

Please sign in to comment.