# sdk_execution_tests.py
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import functools
import os
import sys
from typing import Any, Dict, List, Optional, Tuple

from kfp import client
from kfp import dsl  # noqa
import kfp_server_api
import pytest
import yaml
# Both settings are required: a missing variable raises KeyError on import.
KFP_ENDPOINT = os.environ['KFP_ENDPOINT']
# Kept as the raw environment string; converted with int() where it is used.
TIMEOUT_SECONDS = os.environ['TIMEOUT_SECONDS']

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
# Two directory levels above the directory that contains this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(CURRENT_DIR, os.path.pardir, os.path.pardir))
CONFIG_PATH = os.path.join(
    PROJECT_ROOT, 'sdk', 'python', 'test_data', 'test_data_config.yaml')

# Single client instance shared by the helpers in this module.
kfp_client = client.Client(host=KFP_ENDPOINT)
@dataclasses.dataclass
class TestCase:
    """One executable sample pipeline described by the test data config."""
    # This is a data holder, not a test class. Its name matches pytest's
    # default ``Test*`` class pattern, so opt out of collection explicitly
    # instead of letting pytest warn that it cannot collect a class that has
    # an __init__.
    __test__ = False

    # Label built from the config group name and the module name.
    name: str
    # Path of the Python file that defines the pipeline.
    module_path: str
    # Path of the YAML file that sits next to the Python file.
    yaml_path: str
    # Name of the object to load from the module.
    function_name: str
    # Arguments passed when the run is created; None when the config entry
    # has no "arguments" key.
    arguments: Optional[Dict[str, Any]]
def create_test_case_parameters() -> List[TestCase]:
    """Build the list of executable test cases from the test data config.

    Reads the YAML file at CONFIG_PATH. Its top-level mapping associates a
    group name with a ``test_data_dir`` (joined onto PROJECT_ROOT) and a list
    of ``test_cases``. Each test case provides ``module``, ``name`` and
    ``execute``, plus an optional ``arguments`` entry.

    Returns:
        One TestCase for every entry whose ``execute`` value is truthy, in
        the order the entries appear in the config.
    """
    # Read the config as UTF-8 explicitly rather than relying on the
    # platform's locale-dependent default encoding.
    with open(CONFIG_PATH, encoding='utf-8') as f:
        config = yaml.safe_load(f)

    parameters: List[TestCase] = []
    for name, test_group in config.items():
        test_data_dir = os.path.join(PROJECT_ROOT, test_group['test_data_dir'])
        parameters.extend(
            TestCase(
                name=name + '-' + test_case['module'],
                module_path=os.path.join(test_data_dir,
                                         f'{test_case["module"]}.py'),
                yaml_path=os.path.join(test_data_dir,
                                       f'{test_case["module"]}.yaml'),
                function_name=test_case['name'],
                # Missing "arguments" is passed through as None.
                arguments=test_case.get('arguments'),
            )
            for test_case in test_group['test_cases']
            if test_case['execute'])
    return parameters
def wait(
    run_result: client.client.RunPipelineResult
) -> kfp_server_api.V2beta1Run:
    """Wait for a submitted run and return what the client reports for it.

    Delegates to ``kfp_client.wait_for_run_completion`` using the run id of
    ``run_result`` and TIMEOUT_SECONDS (converted to int) as the timeout.
    """
    run_id = run_result.run_id
    timeout_seconds = int(TIMEOUT_SECONDS)
    return kfp_client.wait_for_run_completion(
        run_id=run_id, timeout=timeout_seconds)
def _loaded_from_other_file(module: Any, python_path: str) -> bool:
    """Return True if ``module`` came from a file other than ``python_path``."""
    loaded_from = getattr(module, '__file__', None)
    # Modules without a file (e.g. built-ins) are left alone.
    return (loaded_from is not None and
            os.path.realpath(loaded_from) != os.path.realpath(python_path))


def import_obj_from_file(python_path: str, obj_name: str) -> Any:
    """Import the module at ``python_path`` and return its ``obj_name`` object.

    The module's directory is put at the front of ``sys.path`` so the module
    (and any sibling module it imports) can be resolved by name.

    Args:
        python_path: Path to the ``.py`` file to import.
        obj_name: Name of the attribute to fetch from the imported module.

    Returns:
        The object bound to ``obj_name`` in the imported module.

    Raises:
        ValueError: If the module has no attribute named ``obj_name``.
    """
    import importlib

    module_dir = os.path.dirname(python_path)
    module_name = os.path.splitext(os.path.split(python_path)[1])[0]

    # Keep the directory first on sys.path without adding one more duplicate
    # entry for every object loaded from the same directory.
    while module_dir in sys.path:
        sys.path.remove(module_dir)
    sys.path.insert(0, module_dir)

    # The import cache is keyed by module name only. If a same-named module
    # from a different file is cached, drop it so the requested file is the
    # one that gets imported.
    cached = sys.modules.get(module_name)
    if cached is not None and _loaded_from_other_file(cached, python_path):
        del sys.modules[module_name]

    module = importlib.import_module(module_name)
    try:
        return getattr(module, obj_name)
    except AttributeError as e:
        raise ValueError(
            f'Object "{obj_name}" not found in module {python_path}.') from e
def run(test_case: TestCase) -> Tuple[str, client.client.RunPipelineResult]:
    """Submit the pipeline described by ``test_case`` and report the run.

    Loads ``test_case.function_name`` from ``test_case.module_path``, creates
    a run for it with caching enabled and the test case's arguments, and
    prints a short summary.

    Returns:
        A tuple of the run's details URL and the result object returned by
        ``kfp_client.create_run_from_pipeline_func``.
    """
    module_file = os.path.join(PROJECT_ROOT, test_case.module_path)
    pipeline = import_obj_from_file(module_file, test_case.function_name)

    submitted = kfp_client.create_run_from_pipeline_func(
        pipeline,
        arguments=test_case.arguments,
        enable_caching=True,
    )

    details_url = f'{KFP_ENDPOINT}/#/runs/details/{submitted.run_id}'
    summary = f'- Created run {test_case.name}\n\tModule: {test_case.module_path}\n\tURL: {details_url}\n'
    print(summary)
    return details_url, submitted
def get_kfp_package_path() -> str:
    """Return the pip-style VCS path of the KFP SDK to use for the tests.

    When the ``PULL_NUMBER`` environment variable holds a non-empty value,
    the path points at that pull request's merge ref; otherwise it points at
    ``master``. The chosen path is printed before it is returned.
    """
    # Install from the official repository host rather than a mirror.
    repo = 'git+https://github.com/kubeflow/pipelines.git'
    pull_number = os.environ.get('PULL_NUMBER')
    # An empty PULL_NUMBER is treated as unset; it would otherwise produce
    # the invalid ref "refs/pull//merge".
    ref = f'refs/pull/{pull_number}/merge' if pull_number else 'master'
    path = f'{repo}@{ref}#subdirectory=sdk/python'
    print(f'Using the following KFP package path for tests: {path}')
    return path
# Rebind dsl.component so that later uses of it in this process get
# kfp_package_path pre-filled with the path chosen by get_kfp_package_path().
# NOTE(review): presumably this makes components install the SDK under test
# rather than a released version; confirm against kfp's dsl.component.
dsl.component = functools.partial(
    dsl.component,
    kfp_package_path=get_kfp_package_path(),
)
@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize('test_case', create_test_case_parameters())
async def test(test_case: TestCase) -> None:
    """Submit one sample pipeline and check that its run succeeds.

    Raises:
        RuntimeError: If submitting the pipeline fails; the original error is
            chained as the cause.
    """
    loop = asyncio.get_running_loop()

    try:
        run_url, run_result = run(test_case)
    except Exception as trigger_error:
        raise RuntimeError(
            f'Error triggering pipeline {test_case.name}.') from trigger_error

    # wait() is a blocking call, so hand it to the loop's default executor
    # and await the result instead of blocking the event loop.
    finished_run = await loop.run_in_executor(None, wait, run_result)

    assert finished_run.state == 'SUCCEEDED', (
        f'Pipeline {test_case.name} ended with incorrect status: '
        f'{finished_run.state}. More info: {run_url}')
# Allow running this file directly as a script; hands control to pytest.
if __name__ == '__main__':
    pytest.main()