diff --git a/kedro_mlflow/framework/context/config.py b/kedro_mlflow/framework/context/config.py
index 1948613d..802b7b90 100644
--- a/kedro_mlflow/framework/context/config.py
+++ b/kedro_mlflow/framework/context/config.py
@@ -18,7 +18,14 @@ class KedroMlflowConfig:
     UI_OPTS = {"port": None, "host": None}
 
-    NODE_HOOK_OPTS = {"flatten_dict_params": False, "recursive": True, "sep": "."}
+    NODE_HOOK_OPTS = {
+        "flatten_dict_params": False,
+        "recursive": True,
+        "sep": ".",
+        "long_parameters_strategy": "fail",
+    }
+
+    AVAILABLE_LONG_PARAMETERS_STRATEGY = ["fail", "truncate", "tag"]
 
     def __init__(
         self,
@@ -134,8 +141,21 @@ def from_dict(self, configuration: Dict[str, str]):
             opts=node_hook_opts, default=self.NODE_HOOK_OPTS
         )
 
+        # this parameter validation should likely be done elsewhere:
+        # when refactoring KedroMlflowConfig, we should use an object
+        # to validate data with a @property
+        if (
+            self.node_hook_opts["long_parameters_strategy"]
+            not in self.AVAILABLE_LONG_PARAMETERS_STRATEGY
+        ):
+            strategy_list = ", ".join(self.AVAILABLE_LONG_PARAMETERS_STRATEGY)
+            raise ValueError(
+                f"'long_parameters_strategy' must be one of [{strategy_list}], "
+                f"got '{self.node_hook_opts['long_parameters_strategy']}'"
+            )
+
         # instantiate mlflow objects to interact with the database
-        # the client must not be create dbefore carefully checking the uri,
+        # the client must not be created before carefully checking the uri,
         # otherwise mlflow creates a mlruns folder to the current location
         self.mlflow_client = mlflow.tracking.MlflowClient(
             tracking_uri=self.mlflow_tracking_uri
diff --git a/kedro_mlflow/framework/hooks/node_hook.py b/kedro_mlflow/framework/hooks/node_hook.py
index ea200951..67ff364f 100644
--- a/kedro_mlflow/framework/hooks/node_hook.py
+++ b/kedro_mlflow/framework/hooks/node_hook.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict
+import logging
+from typing import Any, Dict, Union
 
 import mlflow
 from kedro.framework.context import load_context
@@ -6,6 +7,7 @@
 from kedro.io import DataCatalog
 from kedro.pipeline import Pipeline
 from kedro.pipeline.node import Node
+from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH
 
 from kedro_mlflow.framework.context import get_mlflow_config
 
@@ -16,6 +18,11 @@ def __init__(self):
         self.flatten = False
         self.recursive = True
         self.sep = "."
+        self.long_parameters_strategy = "fail"
+
+    @property
+    def _logger(self) -> logging.Logger:
+        return logging.getLogger(__name__)
 
     @hook_impl
     def before_pipeline_run(
@@ -50,9 +57,13 @@ def before_pipeline_run(
             extra_params=run_params["extra_params"],
         )
         config = get_mlflow_config(self.context)
+
         self.flatten = config.node_hook_opts["flatten_dict_params"]
         self.recursive = config.node_hook_opts["recursive"]
         self.sep = config.node_hook_opts["sep"]
+        self.long_parameters_strategy = config.node_hook_opts[
+            "long_parameters_strategy"
+        ]
 
     @hook_impl
     def before_node_run(
@@ -76,6 +87,7 @@ def before_node_run(
         # only parameters will be logged. Artifacts must be declared manually in the catalog
         params_inputs = {}
         for k, v in inputs.items():
+            # detect parameters automatically based on kedro reserved names
             if k.startswith("params:"):
                 params_inputs[k[7:]] = v
             elif k == "parameters":
@@ -87,9 +99,36 @@ def before_node_run(
                 d=params_inputs, recursive=self.recursive, sep=self.sep
             )
 
-        mlflow.log_params(params_inputs)
-
+        # logging parameters based on defined strategy
+        for k, v in params_inputs.items():
+            self.log_param(k, v)
+
+    def log_param(self, name: str, value: Union[Dict, int, bool, str]) -> None:
+        str_value = str(value)
+        str_value_length = len(str_value)
+        if str_value_length <= MAX_PARAM_VAL_LENGTH:
+            return mlflow.log_param(name, value)
+        else:
+            if self.long_parameters_strategy == "fail":
+                raise ValueError(
+                    f"Parameter '{name}' length is {str_value_length}, "
+                    f"while mlflow forces it to be lower or equal to '{MAX_PARAM_VAL_LENGTH}'. "
+                    "If you want to bypass it, try to change 'long_parameters_strategy' to"
+                    " 'tag' or 'truncate' in the 'mlflow.yml' configuration file."
+                )
+            elif self.long_parameters_strategy == "tag":
+                self._logger.warning(
+                    f"Parameter '{name}' (value length {str_value_length}) is set as a tag."
+                )
+                mlflow.set_tag(name, value)
+            elif self.long_parameters_strategy == "truncate":
+                self._logger.warning(
+                    f"Parameter '{name}' (value length {str_value_length}) is truncated to its first {MAX_PARAM_VAL_LENGTH} characters."
+                )
+                mlflow.log_param(name, str_value[0:MAX_PARAM_VAL_LENGTH])
+
+
+# this hook instantiation is necessary for auto-registration
 mlflow_node_hook = MlflowNodeHook()
diff --git a/kedro_mlflow/template/project/mlflow.yml b/kedro_mlflow/template/project/mlflow.yml
index c3fb8e89..d986f603 100644
--- a/kedro_mlflow/template/project/mlflow.yml
+++ b/kedro_mlflow/template/project/mlflow.yml
@@ -41,6 +41,7 @@ hooks:
     flatten_dict_params: False # if True, parameter which are dictionary will be splitted in multiple parameters when logged in mlflow, one for each key.
     recursive: True # Should the dictionary flattening be applied recursively (i.e for nested dictionaries)? Not use if `flatten_dict_params` is False.
     sep: "." # In case of recursive flattening, what separator should be used between the keys? E.g. {hyperaparam1: {p1:1, p2:2}}will be logged as hyperaparam1.p1 and hyperaparam1.p2 oin mlflow.
+    long_parameters_strategy: fail # One of ["fail", "tag", "truncate"]. If a parameter is above the mlflow limit (currently 250 characters), what should kedro-mlflow do? -> fail, set it as a tag instead of a parameter, or truncate it to its first 250 characters?
 
 # UI-RELATED PARAMETERS -----------------
diff --git a/tests/framework/context/test_mlflow_context.py b/tests/framework/context/test_mlflow_context.py
index cb06b426..f56acf9d 100644
--- a/tests/framework/context/test_mlflow_context.py
+++ b/tests/framework/context/test_mlflow_context.py
@@ -27,7 +27,14 @@ def test_get_mlflow_config(mocker, tmp_path, config_dir):
             experiment=dict(name="fake_package", create=True),
             run=dict(id="123456789", name="my_run", nested=True),
             ui=dict(port="5151", host="localhost"),
-            hooks=dict(node=dict(flatten_dict_params=True, recursive=False, sep="-")),
+            hooks=dict(
+                node=dict(
+                    flatten_dict_params=True,
+                    recursive=False,
+                    sep="-",
+                    long_parameters_strategy="truncate",
+                )
+            ),
         ),
     )
     expected = {
@@ -37,7 +44,12 @@
         "run": {"id": "123456789", "name": "my_run", "nested": True},
         "ui": {"port": "5151", "host": "localhost"},
         "hooks": {
-            "node": {"flatten_dict_params": True, "recursive": False, "sep": "-"}
+            "node": {
+                "flatten_dict_params": True,
+                "recursive": False,
+                "sep": "-",
+                "long_parameters_strategy": "truncate",
+            }
         },
     }
     context = load_context(tmp_path)
@@ -54,7 +66,14 @@ def test_mlflow_config_with_templated_config(mocker, tmp_path, config_dir):
             experiment=dict(name="fake_package", create=True),
             run=dict(id="123456789", name="my_run", nested=True),
             ui=dict(port="5151", host="localhost"),
-            hooks=dict(node=dict(flatten_dict_params=True, recursive=False, sep="-")),
+            hooks=dict(
+                node=dict(
+                    flatten_dict_params=True,
+                    recursive=False,
+                    sep="-",
+                    long_parameters_strategy="truncate",
+                )
+            ),
         ),
     )
 
@@ -70,7 +89,12 @@
         "run": {"id": "123456789", "name": "my_run", "nested": True},
         "ui": {"port": "5151", "host": "localhost"},
         "hooks": {
-            "node": {"flatten_dict_params": True, "recursive": False, "sep": "-"}
+            "node": {
+                "flatten_dict_params": True,
+                "recursive": False,
+                "sep": "-",
+                "long_parameters_strategy": "truncate",
+            }
         },
     }
diff --git a/tests/framework/hooks/test_node_hook.py b/tests/framework/hooks/test_node_hook.py
index 6997b18d..0d3d6615 100644
--- a/tests/framework/hooks/test_node_hook.py
+++ b/tests/framework/hooks/test_node_hook.py
@@ -7,6 +7,7 @@
 from kedro.io import DataCatalog, MemoryDataSet
 from kedro.pipeline import Pipeline, node
 from mlflow.tracking import MlflowClient
+from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH
 
 from kedro_mlflow.framework.hooks import MlflowNodeHook
 from kedro_mlflow.framework.hooks.node_hook import flatten_dict
@@ -196,3 +197,192 @@ def test_node_hook_logging(
     mlflow_client = MlflowClient(mlflow_tracking_uri)
     current_run = mlflow_client.get_run(run_id)
     assert current_run.data.params == expected
+
+
+@pytest.mark.parametrize(
+    "param_length", [MAX_PARAM_VAL_LENGTH - 10, MAX_PARAM_VAL_LENGTH]
+)
+@pytest.mark.parametrize("strategy", ["fail", "truncate", "tag"])
+def test_node_hook_logging_below_limit_all_strategy(
+    tmp_path, config_dir, dummy_run_params, dummy_node, param_length, strategy
+):
+
+    # mocker.patch("kedro_mlflow.utils._is_kedro_project", return_value=True)
+
+    _write_yaml(
+        tmp_path / "conf" / "base" / "mlflow.yml",
+        dict(
+            hooks=dict(node=dict(long_parameters_strategy=strategy)),
+        ),
+    )
+
+    mlflow_tracking_uri = (tmp_path / "mlruns").as_uri()
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+
+    mlflow_node_hook = MlflowNodeHook()
+
+    param_value = param_length * "a"
+    node_inputs = {"params:my_param": param_value}
+
+    with mlflow.start_run():
+        mlflow_node_hook.before_pipeline_run(
+            run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog()
+        )
+        mlflow_node_hook.before_node_run(
+            node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
+            catalog=DataCatalog(),  # can be empty
+            inputs=node_inputs,
+            is_async=False,
+            run_id="132",
+        )
+        run_id = mlflow.active_run().info.run_id
+
+    mlflow_client = MlflowClient(mlflow_tracking_uri)
+    current_run = mlflow_client.get_run(run_id)
+    assert current_run.data.params == {"my_param": param_value}
+
+
+@pytest.mark.parametrize(
+    "param_length",
+    [MAX_PARAM_VAL_LENGTH + 20],
+)
+def test_node_hook_logging_above_limit_truncate_strategy(
+    tmp_path, config_dir, dummy_run_params, dummy_node, param_length
+):
+
+    # mocker.patch("kedro_mlflow.utils._is_kedro_project", return_value=True)
+
+    _write_yaml(
+        tmp_path / "conf" / "base" / "mlflow.yml",
+        dict(
+            hooks=dict(node=dict(long_parameters_strategy="truncate")),
+        ),
+    )
+
+    mlflow_tracking_uri = (tmp_path / "mlruns").as_uri()
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+
+    mlflow_node_hook = MlflowNodeHook()
+
+    param_value = param_length * "a"
+    node_inputs = {"params:my_param": param_value}
+
+    with mlflow.start_run():
+        mlflow_node_hook.before_pipeline_run(
+            run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog()
+        )
+        mlflow_node_hook.before_node_run(
+            node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
+            catalog=DataCatalog(),  # can be empty
+            inputs=node_inputs,
+            is_async=False,
+            run_id="132",
+        )
+        run_id = mlflow.active_run().info.run_id
+
+    mlflow_client = MlflowClient(mlflow_tracking_uri)
+    current_run = mlflow_client.get_run(run_id)
+    assert current_run.data.params == {"my_param": param_value[0:MAX_PARAM_VAL_LENGTH]}
+
+
+@pytest.mark.parametrize(
+    "param_length",
+    [MAX_PARAM_VAL_LENGTH + 20],
+)
+def test_node_hook_logging_above_limit_fail_strategy(
+    tmp_path, config_dir, dummy_run_params, dummy_node, param_length
+):
+
+    # mocker.patch("kedro_mlflow.utils._is_kedro_project", return_value=True)
+
+    _write_yaml(
+        tmp_path / "conf" / "base" / "mlflow.yml",
+        dict(
+            hooks=dict(node=dict(long_parameters_strategy="fail")),
+        ),
+    )
+
+    mlflow_tracking_uri = (tmp_path / "mlruns").as_uri()
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+
+    mlflow_node_hook = MlflowNodeHook()
+
+    param_value = param_length * "a"
+    node_inputs = {"params:my_param": param_value}
+
+    with mlflow.start_run():
+        mlflow_node_hook.before_pipeline_run(
+            run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog()
+        )
+
+        # IMPORTANT: Exceeding the parameter length limit
+        # should raise an error for all mlflow backends,
+        # but it does not on the FileStore backend:
+        # https://github.com/mlflow/mlflow/issues/2814#issuecomment-628284425
+        # Since we use the FileStore backend for simplicity in tests, logging would work,
+        # but we have enforced failure (which is slightly different from mlflow's
+        # behaviour)
+        with pytest.raises(
+            ValueError, match=f"Parameter 'my_param' length is {param_length}"
+        ):
+            mlflow_node_hook.before_node_run(
+                node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
+                catalog=DataCatalog(),  # can be empty
+                inputs=node_inputs,
+                is_async=False,
+                run_id="132",
+            )
+
+
+@pytest.mark.parametrize(
+    "param_length",
+    [MAX_PARAM_VAL_LENGTH + 20],
+)
+def test_node_hook_logging_above_limit_tag_strategy(
+    tmp_path, config_dir, dummy_run_params, dummy_node, param_length
+):
+
+    # mocker.patch("kedro_mlflow.utils._is_kedro_project", return_value=True)
+
+    _write_yaml(
/ "conf" / "base" / "mlflow.yml", + dict( + hooks=dict(node=dict(long_parameters_strategy="tag")), + ), + ) + + mlflow_tracking_uri = (tmp_path / "mlruns").as_uri() + mlflow.set_tracking_uri(mlflow_tracking_uri) + + mlflow_node_hook = MlflowNodeHook() + + param_value = param_length * "a" + node_inputs = {"params:my_param": param_value} + + with mlflow.start_run(): + mlflow_node_hook.before_pipeline_run( + run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog() + ) + + # IMPORTANT: Overpassing the parameters limit + # should raise an error for all mlflow backend + # but it does not on FileStore backend : + # https://github.com/mlflow/mlflow/issues/2814#issuecomment-628284425 + # Since we use FileStore system for simplicty for tests logging works + # But we have enforced failure (which is slightly different from mlflow + # behaviour) + mlflow_node_hook.before_node_run( + node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None), + catalog=DataCatalog(), # can be empty + inputs=node_inputs, + is_async=False, + run_id="132", + ) + run_id = mlflow.active_run().info.run_id + + mlflow_client = MlflowClient(mlflow_tracking_uri) + current_run = mlflow_client.get_run(run_id) + assert current_run.data.params == {} + assert { + k: v for k, v in current_run.data.tags.items() if not k.startswith("mlflow") + } == {"my_param": param_value}