Skip to content

Commit

Permalink
minor clean up
Browse files Browse the repository at this point in the history
Signed-off-by: KevinGrantLee <kglee@google.com>
  • Loading branch information
KevinGrantLee committed Jun 13, 2024
1 parent eb8a587 commit 1c70801
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 67 deletions.
129 changes: 63 additions & 66 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,76 +239,73 @@ def build_task_spec_for_task(
component_input_parameter)

elif isinstance(input_value, (str, int, float, bool, dict, list)):
if isinstance(input_value, (str, dict, list)):
pipeline_channels = (
pipeline_channel.extract_pipeline_channels_from_any(
input_value))
for channel in pipeline_channels:
# NOTE: case like this p3 = print_and_return_str(s='Project = {}'.format(project))
# triggers this code

# value contains PipelineChannel placeholders which needs to be
# replaced. And the input needs to be added to the task spec.

# Form the name for the compiler injected input, and make sure it
# doesn't collide with any existing input names.
additional_input_name = (
compiler_utils
.additional_input_name_for_pipeline_channel(channel))

# We don't expect collision to happen because we prefix the name
# of additional input with 'pipelinechannel--'. But just in case
# collision did happend, throw a RuntimeError so that we don't
# get surprise at runtime.
for existing_input_name, _ in task.inputs.items():
if existing_input_name == additional_input_name:
raise RuntimeError(
f'Name collision between existing input name {existing_input_name} and compiler injected input name {additional_input_name}'
)

additional_input_placeholder = placeholders.InputValuePlaceholder(
additional_input_name)._to_string()

if isinstance(input_value, str):
input_value = input_value.replace(
channel.pattern, additional_input_placeholder)
else:
input_value = compiler_utils.recursive_replace_placeholders(
input_value, channel.pattern,
additional_input_placeholder)

if channel.task_name:
# Value is produced by an upstream task.
if channel.task_name in tasks_in_current_dag:
# Dependent task within the same DAG.
pipeline_task_spec.inputs.parameters[
additional_input_name].task_output_parameter.producer_task = (
utils.sanitize_task_name(channel.task_name))
pipeline_task_spec.inputs.parameters[
additional_input_name].task_output_parameter.output_parameter_key = (
channel.name)
else:
# Dependent task not from the same DAG.
component_input_parameter = (
compiler_utils
.additional_input_name_for_pipeline_channel(
channel))
assert component_input_parameter in parent_component_inputs.parameters, \
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.parameters[
additional_input_name].component_input_parameter = (
component_input_parameter)
pipeline_channels = (
pipeline_channel.extract_pipeline_channels_from_any(input_value)
)
for channel in pipeline_channels:
# NOTE: case like this p3 = print_and_return_str(s='Project = {}'.format(project))
# triggers this code

# value contains PipelineChannel placeholders which needs to be
# replaced. And the input needs to be added to the task spec.

# Form the name for the compiler injected input, and make sure it
# doesn't collide with any existing input names.
additional_input_name = (
compiler_utils.additional_input_name_for_pipeline_channel(
channel))

# We don't expect collision to happen because we prefix the name
# of additional input with 'pipelinechannel--'. But just in case
# collision did happend, throw a RuntimeError so that we don't
# get surprise at runtime.
for existing_input_name, _ in task.inputs.items():
if existing_input_name == additional_input_name:
raise RuntimeError(
f'Name collision between existing input name {existing_input_name} and compiler injected input name {additional_input_name}'
)

additional_input_placeholder = placeholders.InputValuePlaceholder(
additional_input_name)._to_string()

if isinstance(input_value, str):
input_value = input_value.replace(
channel.pattern, additional_input_placeholder)
else:
input_value = compiler_utils.recursive_replace_placeholders(
input_value, channel.pattern,
additional_input_placeholder)

if channel.task_name:
# Value is produced by an upstream task.
if channel.task_name in tasks_in_current_dag:
# Dependent task within the same DAG.
pipeline_task_spec.inputs.parameters[
additional_input_name].task_output_parameter.producer_task = (
utils.sanitize_task_name(channel.task_name))
pipeline_task_spec.inputs.parameters[
additional_input_name].task_output_parameter.output_parameter_key = (
channel.name)
else:
# Value is from pipeline input. (or loop?)
component_input_parameter = channel.full_name
if component_input_parameter not in parent_component_inputs.parameters:
component_input_parameter = (
compiler_utils
.additional_input_name_for_pipeline_channel(
channel))
# Dependent task not from the same DAG.
component_input_parameter = (
compiler_utils.
additional_input_name_for_pipeline_channel(channel))
assert component_input_parameter in parent_component_inputs.parameters, \
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.parameters[
additional_input_name].component_input_parameter = (
component_input_parameter)
else:
# Value is from pipeline input. (or loop?)
component_input_parameter = channel.full_name
if component_input_parameter not in parent_component_inputs.parameters:
component_input_parameter = (
compiler_utils.
additional_input_name_for_pipeline_channel(channel))
pipeline_task_spec.inputs.parameters[
additional_input_name].component_input_parameter = (
component_input_parameter)

pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.CopyFrom(
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/pipeline_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def extract_pipeline_channels_from_string(


def extract_pipeline_channels_from_any(
payload: Union[PipelineChannel, str, list, tuple, dict]
payload: Union[PipelineChannel, str, int, float, bool, list, tuple, dict]
) -> List[PipelineChannel]:
"""Recursively extract PipelineChannels from any object or list of objects.
Expand Down

0 comments on commit 1c70801

Please sign in to comment.