# Patch summary: add unit tests for TrainingClient.wait_for_job_conditions.
#
# Lines before the "diff --git" header are commentary and are skipped by
# patch tools. The hunks below change training_client_test.py as follows:
#   * import KubeflowOrgV1JobCondition and KubeflowOrgV1JobStatus;
#   * add get_job_response(), used as the side_effect of the patched
#     TrainingClient.get_job: when called with keyword namespace=RUNTIME it
#     returns a job whose condition type is constants.JOB_CONDITION_FAILED,
#     otherwise a job built with the default condition type
#     (constants.JOB_CONDITION_SUCCEEDED);
#   * add generate_job_with_status(), which overwrites job.status with a
#     single condition of the given type (status CONDITION_STATUS_TRUE) on the
#     object passed in and returns that same object;
#   * add the test_data_wait_for_job_conditions table: each entry is
#     (test name, kwargs for wait_for_job_conditions, expected return value
#     or expected exception type);
#   * extend the training_client fixture with
#     patch.object(TrainingClient, "get_job", side_effect=get_job_response);
#   * add the parametrized test_wait_for_job_conditions test.
#
# NOTE(review): this layout was restored from a copy whose line breaks had
# been collapsed. Hunk line counts match the headers, but blank context lines
# and the indentation of context lines were inferred, not read from the base
# file -- confirm against the base revision (index 8d0ab6310d) before applying.
diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py
index 8d0ab6310d..344ee76e2e 100644
--- a/sdk/python/kubeflow/training/api/training_client_test.py
+++ b/sdk/python/kubeflow/training/api/training_client_test.py
@@ -4,6 +4,8 @@
 
 import pytest
 from kubeflow.training import (
+    KubeflowOrgV1JobCondition,
+    KubeflowOrgV1JobStatus,
     KubeflowOrgV1PyTorchJob,
     KubeflowOrgV1PyTorchJobSpec,
     KubeflowOrgV1ReplicaSpec,
@@ -70,6 +72,13 @@ def get(self, timeout):
     return MockResponse()
 
 
+def get_job_response(*args, **kwargs):
+    if kwargs.get("namespace") == RUNTIME:
+        return generate_job_with_status(create_job(), constants.JOB_CONDITION_FAILED)
+    else:
+        return generate_job_with_status(create_job())
+
+
 def generate_container() -> V1Container:
     return V1Container(
         name="pytorch",
@@ -127,6 +136,21 @@ def create_job():
     return pytorchjob
 
 
+def generate_job_with_status(
+    job: constants.JOB_MODELS_TYPE,
+    condition_type: str = constants.JOB_CONDITION_SUCCEEDED,
+) -> constants.JOB_MODELS_TYPE:
+    job.status = KubeflowOrgV1JobStatus(
+        conditions=[
+            KubeflowOrgV1JobCondition(
+                type=condition_type,
+                status=constants.CONDITION_STATUS_TRUE,
+            )
+        ]
+    )
+    return job
+
+
 class DummyJobClass:
     def __init__(self, kind) -> None:
         self.kind = kind
@@ -279,6 +303,61 @@ def __init__(self, kind) -> None:
     ),
 ]
 
+test_data_wait_for_job_conditions = [
+    (
+        "timeout waiting for succeeded condition",
+        {
+            "name": TEST_NAME,
+            "namespace": TIMEOUT,
+            "wait_timeout": 0,
+        },
+        TimeoutError,
+    ),
+    (
+        "invalid expected condition",
+        {
+            "name": TEST_NAME,
+            "namespace": "value",
+            "expected_conditions": {"invalid"},
+        },
+        ValueError,
+    ),
+    (
+        "invalid expected condition(lowercase)",
+        {
+            "name": TEST_NAME,
+            "namespace": "value",
+            "expected_conditions": {"succeeded"},
+        },
+        ValueError,
+    ),
+    (
+        "job failed unexpectedly",
+        {
+            "name": TEST_NAME,
+            "namespace": RUNTIME,
+        },
+        RuntimeError,
+    ),
+    (
+        "valid case",
+        {
+            "name": TEST_NAME,
+            "namespace": "test-namespace",
+        },
+        generate_job_with_status(create_job()),
+    ),
+    (
+        "valid case with specified callback",
+        {
+            "name": TEST_NAME,
+            "namespace": "test-namespace",
+            "callback": lambda job: "test train function",
+        },
+        generate_job_with_status(create_job()),
+    ),
+]
+
 
 test_data_get_job_pod_names = [
     (
@@ -354,6 +433,8 @@ def training_client():
         ),
     ), patch(
         "kubernetes.config.load_kube_config", return_value=Mock()
+    ), patch.object(
+        TrainingClient, "get_job", side_effect=get_job_response
     ):
         client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND)
         yield client
@@ -434,3 +515,19 @@ def test_update_job(training_client, test_name, kwargs, expected_output):
     except Exception as e:
         assert type(e) is expected_output
     print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_name,kwargs,expected_output", test_data_wait_for_job_conditions
+)
+def test_wait_for_job_conditions(training_client, test_name, kwargs, expected_output):
+    """
+    test wait_for_job_conditions function of training client
+    """
+    print("Executing test:", test_name)
+    try:
+        out = training_client.wait_for_job_conditions(**kwargs)
+        assert out == expected_output
+    except Exception as e:
+        assert type(e) is expected_output
+    print("test execution complete")