From 00eef586956da5697514d159d3ea4fa4a86d3847 Mon Sep 17 00:00:00 2001 From: Mark Campbell Date: Wed, 4 Sep 2024 13:15:15 +0100 Subject: [PATCH] [SDK] test: add unit test for get_job method of the training_client (#2205) Signed-off-by: Bobbins228 --- .../training/api/training_client_test.py | 203 +++++++++++++----- 1 file changed, 147 insertions(+), 56 deletions(-) diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py index a6a3a41b9f..f844a868c2 100644 --- a/sdk/python/kubeflow/training/api/training_client_test.py +++ b/sdk/python/kubeflow/training/api/training_client_test.py @@ -16,6 +16,7 @@ ) from kubeflow.training.models import V1DeleteOptions from kubernetes.client import ( + ApiClient, V1Container, V1ObjectMeta, V1PodSpec, @@ -42,6 +43,27 @@ def conditional_error_handler(*args, **kwargs): raise RuntimeError() +def serialize_k8s_object(obj): + api_client = ApiClient() + return api_client.sanitize_for_serialization(obj) + + +def get_namespaced_custom_object_response(*args, **kwargs): + if args[2] == "timeout": + raise multiprocessing.TimeoutError() + elif args[2] == "runtime": + raise RuntimeError() + + # Create a serialized Job + serialized_job = serialize_k8s_object(generate_job_with_status(create_job())) + + # Mock the thread and set it's return value to the serialized Job + mock_thread = Mock() + mock_thread.get.return_value = serialized_job + + return mock_thread + + def list_namespaced_pod_response(*args, **kwargs): class MockResponse: def get(self, timeout): @@ -419,6 +441,111 @@ def __init__(self, kind) -> None: ), ] +test_data_get_job = [ + ( + "valid flow with default namespace and default timeout", + {"name": TEST_NAME}, + SUCCESS, + ), + ( + "valid flow with all parameters set", + { + "name": TEST_NAME, + "namespace": TEST_NAME, + "job_kind": constants.PYTORCHJOB_KIND, + "timeout": 120, + }, + SUCCESS, + ), + ( + "invalid flow with default namespace and a Job that doesn't exist", + {"name": TEST_NAME, "job_kind": constants.TFJOB_KIND}, + RuntimeError, + ), + ( + "invalid flow incorrect parameter", + {"name": TEST_NAME, "test": "example"}, + TypeError, + ), + ( + "invalid flow withincorrect value", + {"name": TEST_NAME, "job_kind": "FailJob"}, + ValueError, + ), + ( + "runtime error case", + { + "name": TEST_NAME, + "namespace": "runtime", + "job_kind": constants.PYTORCHJOB_KIND, + }, + RuntimeError, + ), + ( + "invalid flow with timeout error", + {"name": TEST_NAME, "namespace": TIMEOUT}, + TimeoutError, + ), + ( + "invalid flow with runtime error", + {"name": TEST_NAME, "namespace": RUNTIME}, + RuntimeError, + ), +] + + +test_data_delete_job = [ + ( + "valid flow with default namespace", + { + "name": TEST_NAME, + }, + SUCCESS, + ), + ( + "invalid extra parameter", + {"name": TEST_NAME, "namespace": TEST_NAME, "example": "test"}, + TypeError, + ), + ( + "invalid job kind", + {"name": TEST_NAME, "job_kind": "invalid_job_kind"}, + RuntimeError, + ), + ( + "job name missing", + {"namespace": TEST_NAME, "job_kind": constants.PYTORCHJOB_KIND}, + TypeError, + ), + ( + "delete_namespaced_custom_object timeout error", + {"name": TEST_NAME, "namespace": TIMEOUT}, + TimeoutError, + ), + ( + "delete_namespaced_custom_object runtime error", + {"name": TEST_NAME, "namespace": RUNTIME}, + RuntimeError, + ), + ( + "valid flow", + { + "name": TEST_NAME, + "namespace": TEST_NAME, + "job_kind": constants.PYTORCHJOB_KIND, + }, + SUCCESS, + ), + ( + "valid flow with delete options", + { + "name": TEST_NAME, + "delete_options": V1DeleteOptions(grace_period_seconds=30), + }, + SUCCESS, + ), +] + @pytest.fixture def training_client(): @@ -428,6 +555,9 @@ def training_client(): create_namespaced_custom_object=Mock(side_effect=conditional_error_handler), patch_namespaced_custom_object=Mock(side_effect=conditional_error_handler), delete_namespaced_custom_object=Mock(side_effect=conditional_error_handler), + get_namespaced_custom_object=Mock( + side_effect=get_namespaced_custom_object_response + ), ), ), patch( "kubernetes.client.CoreV1Api", @@ -436,8 +566,6 @@ def training_client(): ), ), patch( "kubernetes.config.load_kube_config", return_value=Mock() - ), patch.object( - TrainingClient, "get_job", side_effect=get_job_response ): client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND) yield client @@ -536,69 +664,32 @@ def test_wait_for_job_conditions(training_client, test_name, kwargs, expected_ou print("test execution complete") -test_data_delete_job = [ - ( - "valid flow with default namespace", - { - "name": TEST_NAME, - }, - SUCCESS, - ), - ( - "invalid extra parameter", - {"name": TEST_NAME, "namespace": TEST_NAME, "example": "test"}, - TypeError, - ), - ( - "invalid job kind", - {"name": TEST_NAME, "job_kind": "invalid_job_kind"}, - RuntimeError, - ), - ( - "job name missing", - {"namespace": TEST_NAME, "job_kind": constants.PYTORCHJOB_KIND}, - TypeError, - ), - ( - "delete_namespaced_custom_object timeout error", - {"name": TEST_NAME, "namespace": TIMEOUT}, - TimeoutError, - ), - ( - "delete_namespaced_custom_object runtime error", - {"name": TEST_NAME, "namespace": RUNTIME}, - RuntimeError, - ), - ( - "valid flow", - { - "name": TEST_NAME, - "namespace": TEST_NAME, - "job_kind": constants.PYTORCHJOB_KIND, - }, - SUCCESS, - ), - ( - "valid flow with delete options", - { - "name": TEST_NAME, - "delete_options": V1DeleteOptions(grace_period_seconds=30), - }, - SUCCESS, - ), -] - - @pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_delete_job) def test_delete_job(training_client, test_name, kwargs, expected_output): """ test delete_job function of training client """ print("Executing test: ", test_name) - try: training_client.delete_job(**kwargs) assert expected_output == SUCCESS except Exception as e: assert type(e) is expected_output + + print("test execution complete") + + +@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_get_job) +def test_get_job(training_client, test_name, kwargs, expected_output): + """ + test get_job function of training client + """ + print("Executing test: ", test_name) + + try: + training_client.get_job(**kwargs) + assert expected_output == SUCCESS + except Exception as e: + assert type(e) is expected_output + print("test execution complete")