Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK] Add UTs for wait_for_job_conditions #2196

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions sdk/python/kubeflow/training/api/training_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import pytest
from kubeflow.training import (
KubeflowOrgV1JobCondition,
KubeflowOrgV1JobStatus,
KubeflowOrgV1PyTorchJob,
KubeflowOrgV1PyTorchJobSpec,
KubeflowOrgV1ReplicaSpec,
Expand Down Expand Up @@ -70,6 +72,13 @@ def get(self, timeout):
return MockResponse()


def get_job_response(*args, **kwargs):
if kwargs.get("namespace") == RUNTIME:
return generate_job_with_status(create_job(), constants.JOB_CONDITION_FAILED)
else:
return generate_job_with_status(create_job())


def generate_container() -> V1Container:
return V1Container(
name="pytorch",
Expand Down Expand Up @@ -127,6 +136,21 @@ def create_job():
return pytorchjob


def generate_job_with_status(
job: constants.JOB_MODELS_TYPE,
condition_type: str = constants.JOB_CONDITION_SUCCEEDED,
) -> constants.JOB_MODELS_TYPE:
job.status = KubeflowOrgV1JobStatus(
conditions=[
KubeflowOrgV1JobCondition(
type=condition_type,
status=constants.CONDITION_STATUS_TRUE,
)
]
)
return job


class DummyJobClass:
def __init__(self, kind) -> None:
self.kind = kind
Expand Down Expand Up @@ -279,6 +303,61 @@ def __init__(self, kind) -> None:
),
]

test_data_wait_for_job_conditions = [
(
"timeout waiting for succeeded condition",
{
"name": TEST_NAME,
"namespace": TIMEOUT,
"wait_timeout": 0,
},
TimeoutError,
),
(
"invalid expected condition",
{
"name": TEST_NAME,
"namespace": "value",
"expected_conditions": {"invalid"},
},
ValueError,
),
(
"invalid expected condition(lowercase)",
{
"name": TEST_NAME,
"namespace": "value",
"expected_conditions": {"succeeded"},
},
ValueError,
),
(
"job failed unexpectedly",
{
"name": TEST_NAME,
"namespace": RUNTIME,
},
RuntimeError,
),
(
"valid case",
{
"name": TEST_NAME,
"namespace": "test-namespace",
},
generate_job_with_status(create_job()),
),
(
"valid case with specified callback",
{
"name": TEST_NAME,
"namespace": "test-namespace",
"callback": lambda job: "test train function",
},
generate_job_with_status(create_job()),
),
]


test_data_get_job_pod_names = [
(
Expand Down Expand Up @@ -354,6 +433,8 @@ def training_client():
),
), patch(
"kubernetes.config.load_kube_config", return_value=Mock()
), patch.object(
TrainingClient, "get_job", side_effect=get_job_response
):
client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND)
yield client
Expand Down Expand Up @@ -434,3 +515,19 @@ def test_update_job(training_client, test_name, kwargs, expected_output):
except Exception as e:
assert type(e) is expected_output
print("test execution complete")


@pytest.mark.parametrize(
"test_name,kwargs,expected_output", test_data_wait_for_job_conditions
)
def test_wait_for_job_conditions(training_client, test_name, kwargs, expected_output):
"""
test wait_for_job_conditions function of training client
"""
print("Executing test:", test_name)
try:
out = training_client.wait_for_job_conditions(**kwargs)
assert out == expected_output
except Exception as e:
assert type(e) is expected_output
print("test execution complete")
Loading