Skip to content

Commit

Permalink
feat(sdk): Support pipeline outputs (#8204)
Browse files Browse the repository at this point in the history
* Support pipeline outputs

* release note
  • Loading branch information
chensun authored Aug 29, 2022
1 parent 72c1d10 commit 48574dc
Show file tree
Hide file tree
Showing 9 changed files with 455 additions and 117 deletions.
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Major Features and Improvements
* Support parallelism setting in ParallelFor [\#8146](https://github.com/kubeflow/pipelines/pull/8146)
* Support for Python v3.10 [\#8186](https://github.com/kubeflow/pipelines/pull/8186)
* Support pipeline as a component [\#8179](https://github.com/kubeflow/pipelines/pull/8179), [\#8204](https://github.com/kubeflow/pipelines/pull/8204)

## Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions sdk/python/kfp/compiler/_read_write_test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
'pipeline_with_parallelfor_parallelism',
'pipeline_in_pipeline',
'pipeline_in_pipeline_complex',
'pipeline_with_outputs',
],
'test_data_dir': 'sdk/python/kfp/compiler/test_data/pipelines',
'config': {
Expand Down
108 changes: 19 additions & 89 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,25 @@ def compile(

with type_utils.TypeCheckManager(enable=type_check):
if isinstance(pipeline_func, graph_component.GraphComponent):
pipeline_spec = self._create_pipeline(
pipeline_func=pipeline_func.pipeline_func,
pipeline_name=pipeline_name,
pipeline_parameters_override=pipeline_parameters,
)
# Retrieve the pre-compiled pipeline spec.
pipeline_spec = pipeline_func.component_spec.implementation.graph

# Verify that pipeline_parameters contains only input names
# that match the pipeline inputs definition.
for input_name, input_value in (pipeline_parameters or
{}).items():
if input_name in pipeline_spec.root.input_definitions.parameters:
pipeline_spec.root.input_definitions.parameters[
input_name].default_value.CopyFrom(
builder.to_protobuf_value(input_value))
elif input_name in pipeline_spec.root.input_definitions.artifacts:
raise NotImplementedError(
'Default value for artifact input is not supported yet.'
)
else:
raise ValueError(
'Pipeline parameter {} does not match any known '
'pipeline input.'.format(input_name))

elif isinstance(pipeline_func, base_component.BaseComponent):
component_spec = builder.modify_component_spec_for_compile(
Expand All @@ -94,87 +108,3 @@ def compile(
f'decorator. Got: {type(pipeline_func)}')
builder.write_pipeline_spec_to_file(
pipeline_spec=pipeline_spec, package_path=package_path)

def _create_pipeline(
    self,
    pipeline_func: Callable[..., Any],
    pipeline_name: Optional[str] = None,
    pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineSpec:
    """Creates a pipeline instance and constructs the pipeline spec from
    it.

    Executes the pipeline function inside a `pipeline_context.Pipeline`
    context so the tasks it creates are captured, then builds the
    PipelineSpec proto from the recorded task graph.

    Args:
        pipeline_func: The pipeline function with @dsl.pipeline decorator.
        pipeline_name: Optional; the name of the pipeline. Falls back to
            the name extracted from the pipeline function's interface.
        pipeline_parameters_override: Optional; the mapping from parameter
            names to values. Each key must match a declared pipeline input.

    Returns:
        A PipelineSpec proto representing the compiled pipeline.

    Raises:
        ValueError: If the pipeline function creates no tasks, or if an
            override name does not match any declared pipeline input.
    """

    # pipeline_func may be a GraphComponent instance; retrieve its original
    # pipeline function.
    pipeline_func = getattr(pipeline_func, 'python_func', pipeline_func)

    # Create the arg list with no default values and call pipeline function.
    # Assign type information to the PipelineChannel
    pipeline_meta = component_factory.extract_component_interface(
        pipeline_func)
    pipeline_name = pipeline_name or pipeline_meta.name

    # `pipeline_root` is only present if set via the @dsl.pipeline decorator.
    pipeline_root = getattr(pipeline_func, 'pipeline_root', None)

    args_list = []
    signature = inspect.signature(pipeline_func)

    # Build one placeholder channel per declared parameter; these stand in
    # for runtime values while the function body is traced below.
    for arg_name in signature.parameters:
        arg_type = pipeline_meta.inputs[arg_name].type
        args_list.append(
            pipeline_channel.create_pipeline_channel(
                name=arg_name,
                channel_type=arg_type,
            ))

    # Trace the pipeline: calling the function inside the context records
    # every task it creates onto `dsl_pipeline`.
    with pipeline_context.Pipeline(pipeline_name) as dsl_pipeline:
        pipeline_func(*args_list)

    if not dsl_pipeline.tasks:
        raise ValueError('Task is missing from pipeline.')

    pipeline_inputs = pipeline_meta.inputs or {}

    # Verify that pipeline_parameters_override contains only input names
    # that match the pipeline inputs definition.
    pipeline_parameters_override = pipeline_parameters_override or {}
    for input_name in pipeline_parameters_override:
        if input_name not in pipeline_inputs:
            raise ValueError(
                'Pipeline parameter {} does not match any known '
                'pipeline argument.'.format(input_name))

    # Fill in the default values: an explicit override wins over the
    # default declared in the function signature.
    args_list_with_defaults = [
        pipeline_channel.create_pipeline_channel(
            name=input_name,
            channel_type=input_spec.type,
            value=pipeline_parameters_override.get(input_name) or
            input_spec.default,
        ) for input_name, input_spec in pipeline_inputs.items()
    ]

    # Making the pipeline group name unique to prevent name clashes with
    # templates
    pipeline_group = dsl_pipeline.groups[0]
    pipeline_group.name = uuid.uuid4().hex

    pipeline_spec, _ = builder.create_pipeline_spec_and_deployment_config(
        pipeline_args=args_list_with_defaults,
        pipeline=dsl_pipeline,
    )

    if pipeline_root:
        pipeline_spec.default_pipeline_root = pipeline_root

    return pipeline_spec
45 changes: 44 additions & 1 deletion sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json
import os
import re
import subprocess
import tempfile
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, NamedTuple, Optional
import unittest

from absl.testing import parameterized
Expand Down Expand Up @@ -681,6 +682,48 @@ def my_pipeline():
self.assertTrue('exec-print-op' in
pipeline_spec['deploymentSpec']['executors'])

def test_pipeline_with_invalid_output(self):
    """Compiling must fail when the returned namedtuple's field name
    ('msg1') does not match any declared pipeline output ('msg')."""
    with self.assertRaisesRegex(ValueError,
                                'Pipeline output not defined: msg1'):

        @dsl.component
        def print_op(msg: str) -> str:
            print(msg)

        # Declares an output named 'msg' but returns one named 'msg1';
        # the @dsl.pipeline decorator is expected to raise at definition
        # time, inside the assertRaisesRegex context.
        @dsl.pipeline
        def my_pipeline() -> NamedTuple('Outputs', [
            ('msg', str),
        ]):
            task = print_op(msg='Hello')
            output = collections.namedtuple('Outputs', ['msg1'])
            return output(task.output)
def test_pipeline_with_missing_output(self):
    """Compiling must fail when a declared pipeline output (parameter or
    artifact) is never returned by the pipeline function."""
    # Case 1: declared parameter output 'msg' with no return statement.
    with self.assertRaisesRegex(ValueError, 'Missing pipeline output: msg'):

        @dsl.component
        def print_op(msg: str) -> str:
            print(msg)

        @dsl.pipeline
        def my_pipeline() -> NamedTuple('Outputs', [
            ('msg', str),
        ]):
            task = print_op(msg='Hello')

    # Case 2: declared artifact output 'model' with no return statement.
    with self.assertRaisesRegex(ValueError,
                                'Missing pipeline output: model'):

        @dsl.component
        def print_op(msg: str) -> str:
            print(msg)

        @dsl.pipeline
        def my_pipeline() -> NamedTuple('Outputs', [
            ('model', dsl.Model),
        ]):
            task = print_op(msg='Hello')


class V2NamespaceAliasTest(unittest.TestCase):
"""Test that imports of both modules and objects are aliased (e.g. all
Expand Down
Loading

0 comments on commit 48574dc

Please sign in to comment.