Skip to content

Commit

Permalink
fix bool default deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccarthy committed Dec 29, 2022
1 parent 01e097c commit 7cb5921
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sdk/python/kfp/components/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ast
import collections
import dataclasses
from distutils.util import strtobool
import itertools
import re
from typing import Any, Dict, List, Mapping, Optional, Union
Expand Down Expand Up @@ -647,6 +648,8 @@ def from_v1_component_spec(
type_ = spec.get('type')
optional = spec.get('optional', False) or 'default' in spec
default = spec.get('default')
if type_ == 'Boolean' and isinstance(default, str):
default = strtobool(default) == 1

if isinstance(type_, str) and type_ == 'PipelineTaskFinalStatus':
inputs[utils.sanitize_input_name(spec['name'])] = InputSpec(
Expand All @@ -655,7 +658,6 @@ def from_v1_component_spec(

elif isinstance(type_, str) and type_.lower(
) in type_utils._PARAMETER_TYPES_MAPPING:
default = spec.get('default')
type_enum = type_utils._PARAMETER_TYPES_MAPPING[type_.lower()]
ir_parameter_type_name = pipeline_spec_pb2.ParameterType.ParameterTypeEnum.Name(
type_enum)
Expand Down
163 changes: 163 additions & 0 deletions sdk/python/kfp/components/structures_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,5 +921,168 @@ def test_to_proto(self):
self.assertEqual(retry_policy_proto.backoff_max_duration.seconds, 3600)


class TestDeserializeV1ComponentYamlWithDefaultBool(unittest.TestCase):

def test_uppercase_T_True(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: "True" }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, True)

def test_lowercase_t_true(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: "true" }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, True)

def test_uppercase_F_False(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: "False" }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, False)

def test_lowercase_f_false(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: "false" }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, False)

def test_uppercase_T_True_no_quotes(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: True }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, True)

def test_lowercase_t_true_no_quotes(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: true }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, True)

def test_uppercase_F_False_no_quotes(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: False }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, False)

def test_lowercase_f_false_no_quotes(self):
comp_text = textwrap.dedent("""\
name: test_bool
inputs:
- { name: val, type: Boolean, default: false }
implementation:
container:
image: python:3.7
command:
- sh
- -c
- |
echo $0
- { inputValue: val }
""")
comp = components.load_component_from_text(comp_text)
self.assertEqual(
comp.pipeline_spec.root.input_definitions.parameters['val']
.default_value.bool_value, False)


if __name__ == '__main__':
unittest.main()

0 comments on commit 7cb5921

Please sign in to comment.