Skip to content

Commit

Permalink
add test for None default parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccarthy committed Jan 5, 2024
1 parent 64d46df commit 348bf7a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
10 changes: 4 additions & 6 deletions sdk/python/kfp/local/executor_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,15 @@ def cast_protobuf_numbers(
struct_pb2.Value to a dict/json, int will be upcast to float, even
if the component output specifies int.
"""
int_output_keys = [
output_parameter_types = [
output_param_name
for output_param_name, parameter_spec in output_parameter_types.items()
if parameter_spec.parameter_type ==
pipeline_spec_pb2.ParameterType.ParameterTypeEnum.NUMBER_INTEGER
]
for int_output_key in int_output_keys:
# avoid KeyError when the user never writes to the dsl.OutputPath
if int_output_key in output_parameters:
output_parameters[int_output_key] = int(
output_parameters[int_output_key])
for float_output_key in output_parameter_types:
output_parameters[float_output_key] = int(
output_parameters[float_output_key])
return output_parameters


Expand Down
12 changes: 11 additions & 1 deletion sdk/python/kfp/local/subprocess_task_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Tests for subprocess_local_task_handler.py."""
import contextlib
import io
from typing import NamedTuple
from typing import NamedTuple, Optional
import unittest
from unittest import mock

Expand Down Expand Up @@ -417,6 +417,16 @@ def my_comp(out_param: dsl.OutputPath(int)):
task = my_comp()
self.assertEmpty(task.outputs)

def test_optional_param(self):
local.init(runner=local.SubprocessRunner(use_venv=True))

@dsl.component
def my_comp(string: Optional[str] = None) -> str:
return 'is none' if string is None else 'not none'

task = my_comp()
self.assertEqual(task.output, 'is none')


if __name__ == '__main__':
unittest.main()

0 comments on commit 348bf7a

Please sign in to comment.