From 486bd183fffc823757f7f3c41cf6dca7d7586e1a Mon Sep 17 00:00:00 2001 From: connor-mccarthy Date: Tue, 19 Dec 2023 22:24:49 -0500 Subject: [PATCH] add special dsl.OutputPath read logic --- sdk/python/kfp/local/executor_output_utils.py | 38 +++++++++----- .../kfp/local/executor_output_utils_test.py | 49 ++++++++++++++++--- 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/sdk/python/kfp/local/executor_output_utils.py b/sdk/python/kfp/local/executor_output_utils.py index b919a6029be..70d184451ff 100644 --- a/sdk/python/kfp/local/executor_output_utils.py +++ b/sdk/python/kfp/local/executor_output_utils.py @@ -99,18 +99,32 @@ def get_outputs_from_executor_output( return {**output_parameters, **output_artifacts} -def special_dsl_outputpath_read(output_file: str, is_string: bool) -> Any: +def special_dsl_outputpath_read( + parameter_name: str, + output_file: str, + dtype: pipeline_spec_pb2.ParameterType.ParameterTypeEnum, +) -> Any: """Reads the text in dsl.OutputPath files in the same way as the remote backend. - Basically deserialize all types as JSON, but also support strings - that are written directly without quotes (e.g., `foo` instead of - `"foo"`). + In brief: read strings as strings and JSON load everything else. """ - with open(output_file) as f: - parameter_value = f.read() - # TODO: verify this is the correct special handling of OutputPath - return parameter_value if is_string else json.loads(parameter_value) + try: + with open(output_file) as f: + value = f.read() + + if dtype == pipeline_spec_pb2.ParameterType.ParameterTypeEnum.STRING: + value = value + elif dtype == pipeline_spec_pb2.ParameterType.ParameterTypeEnum.BOOLEAN: + # permit true/True and false/False, consistent with remote BE + value = json.loads(value.lower()) + else: + value = json.loads(value) + return value + except Exception as e: + raise ValueError( + f'Could not deserialize output {parameter_name!r} from path {output_file}' + ) from e def merge_dsl_output_file_parameters_to_executor_output( @@ -123,11 +137,11 @@ def merge_dsl_output_file_parameters_to_executor_output( for parameter_key, output_parameter in executor_input.outputs.parameters.items( ): if os.path.exists(output_parameter.output_file): - is_string = component_spec.output_definitions.parameters[ - parameter_key].parameter_type == pipeline_spec_pb2.ParameterType.ParameterTypeEnum.STRING parameter_value = special_dsl_outputpath_read( - output_parameter.output_file, - is_string, + parameter_name=parameter_key, + output_file=output_parameter.output_file, + dtype=component_spec.output_definitions + .parameters[parameter_key].parameter_type, ) executor_output.parameter_values[parameter_key].CopyFrom( pipeline_spec_builder.to_protobuf_value(parameter_value)) diff --git a/sdk/python/kfp/local/executor_output_utils_test.py b/sdk/python/kfp/local/executor_output_utils_test.py index c39f2d92539..ab509a40b15 100644 --- a/sdk/python/kfp/local/executor_output_utils_test.py +++ b/sdk/python/kfp/local/executor_output_utils_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for executor_output_utils.py.""" +import json import os import tempfile from typing import List @@ -580,19 +581,55 @@ def test(self): class TestSpecialDslOutputPathRead(parameterized.TestCase): - @parameterized.parameters([('foo', 'foo', True)]) - def test(self, written_string, expected_object, is_string): + @parameterized.parameters([ + ('foo', 'foo', + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.STRING), + ('foo', 'foo', + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.STRING), + ('true', True, + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.BOOLEAN), + ('True', True, + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.BOOLEAN), + ('false', False, + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.BOOLEAN), + ('False', False, + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.BOOLEAN), + (json.dumps({'x': 'y'}), { + 'x': 'y' + }, pipeline_spec_pb2.ParameterType.ParameterTypeEnum.STRUCT), + ('3.14', 3.14, + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.NUMBER_DOUBLE), + ('100', 100, + pipeline_spec_pb2.ParameterType.ParameterTypeEnum.NUMBER_INTEGER), + ]) + def test(self, written, expected, dtype): with tempfile.TemporaryDirectory() as tempdir: output_file = os.path.join(tempdir, 'Output') with open(output_file, 'w') as f: - f.write(written_string) + f.write(written) actual = executor_output_utils.special_dsl_outputpath_read( - output_file, - is_string=is_string, + parameter_name='name', + output_file=output_file, + dtype=dtype, ) - self.assertEqual(actual, expected_object) + self.assertEqual(actual, expected) + + def test_exception(self): + with tempfile.TemporaryDirectory() as tempdir: + output_file = os.path.join(tempdir, 'Output') + with open(output_file, 'w') as f: + f.write(str({'x': 'y'})) + with self.assertRaisesRegex( + ValueError, + r"Could not deserialize output 'name' from path"): + executor_output_utils.special_dsl_outputpath_read( + parameter_name='name', + output_file=output_file, + dtype=pipeline_spec_pb2.ParameterType.ParameterTypeEnum + .STRUCT, + ) def assert_artifacts_equal(