Skip to content

Commit

Permalink
TP: added more unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Chris Helma <chelma+github@amazon.com>
  • Loading branch information
chelma committed Dec 10, 2024
1 parent babea3b commit fc36eb2
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from django.test import TestCase
from unittest.mock import patch, MagicMock
from requests import HTTPError, ConnectionError
from transform_expert.utils.opensearch_client import OpenSearchClient
from transform_expert.utils.rest_client import RESTClient, ConnectionDetails


class OpenSearchClientTestCase(TestCase):
def setUp(self):
# Initialize RESTClient and OpenSearchClient
self.connection_details = ConnectionDetails(base_url="http://opensearch.example.com")
self.rest_client = RESTClient(connection_details=self.connection_details)
self.os_client = OpenSearchClient(rest_client=self.rest_client)

@patch("transform_expert.utils.rest_client.RESTClient.get")
def test_is_accessible_happy_path(self, mock_get):
# Mock a successful GET request
mock_get.return_value = {}

result = self.os_client.is_accessible()

# Assertions
mock_get.assert_called_once_with("")
self.assertTrue(result)

@patch("transform_expert.utils.rest_client.RESTClient.get")
def test_is_accessible_error_path(self, mock_get):
# Mock HTTPError
mock_get.side_effect = HTTPError("Not Found")

result = self.os_client.is_accessible()

# Assertions
mock_get.assert_called_once_with("")
self.assertFalse(result)

# Mock ConnectionError
mock_get.reset_mock()
mock_get.side_effect = ConnectionError("Connection failed")

result = self.os_client.is_accessible()

# Assertions
mock_get.assert_called_once_with("")
self.assertFalse(result)

@patch("transform_expert.utils.rest_client.RESTClient.put")
def test_create_index_happy_path(self, mock_put):
# Mock a successful PUT request
mock_response = {"acknowledged": True}
mock_put.return_value = mock_response

index_name = "test-index"
settings = {"settings": {"number_of_shards": 1}}
result = self.os_client.create_index(index_name, settings)

# Assertions
mock_put.assert_called_once_with("test-index", data=settings)
self.assertEqual(result, mock_response)

@patch("transform_expert.utils.rest_client.RESTClient.get")
def test_describe_index_happy_path(self, mock_get):
# Mock a successful GET request
mock_response = {"index": "test-index", "settings": {}}
mock_get.return_value = mock_response

index_name = "test-index"
result = self.os_client.describe_index(index_name)

# Assertions
mock_get.assert_called_once_with("test-index")
self.assertEqual(result, mock_response)

@patch("transform_expert.utils.rest_client.RESTClient.put")
def test_update_index_happy_path(self, mock_put):
# Mock a successful PUT request
mock_response = {"acknowledged": True}
mock_put.return_value = mock_response

index_name = "test-index"
settings = {"settings": {"number_of_replicas": 2}}
result = self.os_client.update_index(index_name, settings)

# Assertions
mock_put.assert_called_once_with("test-index/_settings", data=settings)
self.assertEqual(result, mock_response)

@patch("transform_expert.utils.rest_client.RESTClient.delete")
def test_delete_index_happy_path(self, mock_delete):
# Mock a successful DELETE request
mock_response = {"acknowledged": True}
mock_delete.return_value = mock_response

index_name = "test-index"
result = self.os_client.delete_index(index_name)

# Assertions
mock_delete.assert_called_once_with("test-index")
self.assertEqual(result, mock_response)
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import requests

from django.test import TestCase
from unittest.mock import patch, MagicMock

from transform_expert.utils.rest_client import RESTClient, ConnectionDetails


class RESTClientTestCase(TestCase):
def setUp(self):
self.connection_details = ConnectionDetails(base_url="http://api.example.com")
self.client = RESTClient(connection_details=self.connection_details)

@patch("requests.get")
def test_get_happy_path(self, mock_get):
# Mock the GET response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"key": "value"}
mock_get.return_value = mock_response

response = self.client.get("test-endpoint")

# Assertions
mock_get.assert_called_once_with("http://api.example.com/test-endpoint")
self.assertEqual(response, {"key": "value"})

@patch("requests.get")
def test_get_error_path(self, mock_get):
# Mock a GET response with an error
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = requests.HTTPError("Not Found")
mock_get.return_value = mock_response

with self.assertRaises(requests.HTTPError):
self.client.get("test-endpoint")

# Assertions
mock_get.assert_called_once_with("http://api.example.com/test-endpoint")

@patch("requests.put")
def test_put_happy_path(self, mock_put):
# Mock the PUT response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"key": "value"}
mock_put.return_value = mock_response

response = self.client.put("test-endpoint", data={"data_key": "data_value"})

# Assertions
mock_put.assert_called_once_with(
"http://api.example.com/test-endpoint", json={"data_key": "data_value"}
)
self.assertEqual(response, {"key": "value"})

@patch("requests.put")
def test_put_error_path(self, mock_put):
# Mock a PUT response with an error
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = requests.HTTPError("Server Error")
mock_put.return_value = mock_response

with self.assertRaises(requests.HTTPError):
self.client.put("test-endpoint", data={"data_key": "data_value"})

# Assertions
mock_put.assert_called_once_with(
"http://api.example.com/test-endpoint", json={"data_key": "data_value"}
)

@patch("requests.post")
def test_post_happy_path(self, mock_post):
# Mock the POST response
mock_response = MagicMock()
mock_response.status_code = 201
mock_response.json.return_value = {"key": "value"}
mock_post.return_value = mock_response

response = self.client.post("test-endpoint", data={"data_key": "data_value"})

# Assertions
mock_post.assert_called_once_with(
"http://api.example.com/test-endpoint", json={"data_key": "data_value"}
)
self.assertEqual(response, {"key": "value"})

@patch("requests.post")
def test_post_error_path(self, mock_post):
# Mock a POST response with an error
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = requests.HTTPError("Bad Request")
mock_post.return_value = mock_response

with self.assertRaises(requests.HTTPError):
self.client.post("test-endpoint", data={"data_key": "data_value"})

# Assertions
mock_post.assert_called_once_with(
"http://api.example.com/test-endpoint", json={"data_key": "data_value"}
)

@patch("requests.delete")
def test_delete_happy_path(self, mock_delete):
# Mock the DELETE response
mock_response = MagicMock()
mock_response.status_code = 204
mock_response.json.return_value = {}
mock_delete.return_value = mock_response

response = self.client.delete("test-endpoint")

# Assertions
mock_delete.assert_called_once_with("http://api.example.com/test-endpoint")
self.assertEqual(response, {})

@patch("requests.delete")
def test_delete_error_path(self, mock_delete):
# Mock a DELETE response with an error
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = requests.HTTPError("Unauthorized")
mock_delete.return_value = mock_response

with self.assertRaises(requests.HTTPError):
self.client.delete("test-endpoint")

# Assertions
mock_delete.assert_called_once_with("http://api.example.com/test-endpoint")
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from requests import HTTPError, ConnectionError
from typing import Optional, Dict, Any

from transform_expert.utils.rest_client import RESTClient

logger = logging.getLogger(__name__)
logger = logging.getLogger("transform_expert")

class OpenSearchClient():
rest_client: RESTClient
Expand All @@ -19,7 +20,7 @@ def is_accessible(self) -> bool:
try:
self.rest_client.get("")
return True
except Exception as e:
except (HTTPError, ConnectionError) as e:
logger.error(f"OpenSearch Cluster is not accessible: {str(e)}")
return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import logging
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)
logger = logging.getLogger("transform_expert")

@dataclass
class ConnectionDetails:
base_url: str

# Raw REST client responsible for making HTTP requests to the OpenSearch cluster
class RESTClient():
def __init__(self, connection_details: ConnectionDetails) -> None:
self.base_url = connection_details.base_url.rstrip('/')
Expand Down

0 comments on commit fc36eb2

Please sign in to comment.