Skip to content

Commit

Permalink
FIX #69 - Handle parameters which exceed mlflow limit of 250 characters
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Dec 16, 2020
1 parent c030ce9 commit 4434464
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 9 deletions.
24 changes: 22 additions & 2 deletions kedro_mlflow/framework/context/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ class KedroMlflowConfig:

UI_OPTS = {"port": None, "host": None}

NODE_HOOK_OPTS = {"flatten_dict_params": False, "recursive": True, "sep": "."}
NODE_HOOK_OPTS = {
"flatten_dict_params": False,
"recursive": True,
"sep": ".",
"long_parameters_strategy": "fail",
}

AVAILABLE_LONG_PARAMETERS_STRATEGY = ["fail", "truncate", "tag"]

def __init__(
self,
Expand Down Expand Up @@ -134,8 +141,21 @@ def from_dict(self, configuration: Dict[str, str]):
opts=node_hook_opts, default=self.NODE_HOOK_OPTS
)

# this parameter validation should likely be elsewhere
# when refactoring KedroMlflowConfig, we should use an object
# to validate data with a @property
if (
self.node_hook_opts["long_parameters_strategy"]
not in self.AVAILABLE_LONG_PARAMETERS_STRATEGY
):
strategy_list = ", ".join(self.AVAILABLE_LONG_PARAMETERS_STRATEGY)
raise ValueError(
f"'long_parameters_strategy' must be one of [{strategy_list}], "
f"got '{self.node_hook_opts['long_parameters_strategy']}'"
)

# instantiate mlflow objects to interact with the database
# the client must not be create dbefore carefully checking the uri,
# the client must not be created before carefully checking the uri,
# otherwise mlflow creates a mlruns folder to the current location
self.mlflow_client = mlflow.tracking.MlflowClient(
tracking_uri=self.mlflow_tracking_uri
Expand Down
45 changes: 42 additions & 3 deletions kedro_mlflow/framework/hooks/node_hook.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any, Dict
import logging
from typing import Any, Dict, Union

import mlflow
from kedro.framework.context import load_context
from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH

from kedro_mlflow.framework.context import get_mlflow_config

Expand All @@ -16,6 +18,11 @@ def __init__(self):
self.flatten = False
self.recursive = True
self.sep = "."
self.long_parameters_strategy = "fail"

@property
def _logger(self) -> logging.Logger:
    """Return the logger registered under this module's ``__name__``."""
    module_logger = logging.getLogger(__name__)
    return module_logger

@hook_impl
def before_pipeline_run(
Expand Down Expand Up @@ -50,9 +57,13 @@ def before_pipeline_run(
extra_params=run_params["extra_params"],
)
config = get_mlflow_config(self.context)

self.flatten = config.node_hook_opts["flatten_dict_params"]
self.recursive = config.node_hook_opts["recursive"]
self.sep = config.node_hook_opts["sep"]
self.long_parameters_strategy = config.node_hook_opts[
"long_parameters_strategy"
]

@hook_impl
def before_node_run(
Expand All @@ -76,6 +87,7 @@ def before_node_run(
# only parameters will be logged. Artifacts must be declared manually in the catalog
params_inputs = {}
for k, v in inputs.items():
# detect parameters automatically based on kedro reserved names
if k.startswith("params:"):
params_inputs[k[7:]] = v
elif k == "parameters":
Expand All @@ -87,9 +99,36 @@ def before_node_run(
d=params_inputs, recursive=self.recursive, sep=self.sep
)

mlflow.log_params(params_inputs)

# logging parameters based on defined strategy
for k, v in params_inputs.items():
self.log_param(k, v)

def log_param(self, name: str, value: Union[Dict, int, bool, str]) -> None:
    """Log one parameter to mlflow, handling values that exceed mlflow's length limit.

    The length is measured on the ``str()`` representation of ``value``.
    When it fits within ``MAX_PARAM_VAL_LENGTH`` the parameter is logged
    unchanged. Otherwise ``self.long_parameters_strategy`` decides:

    * ``"fail"``: raise a ``ValueError``;
    * ``"tag"``: emit a warning and record the value with ``mlflow.set_tag``;
    * ``"truncate"``: emit a warning and log only the first
      ``MAX_PARAM_VAL_LENGTH`` characters of the string representation.

    Args:
        name: Name under which the parameter is recorded.
        value: Value of the parameter.

    Raises:
        ValueError: If the value is too long and the strategy is ``"fail"``,
            or if the strategy is not one of the supported values.
    """
    str_value = str(value)
    str_value_length = len(str_value)

    # Common case: the value fits, so it is logged untouched whatever the strategy.
    if str_value_length <= MAX_PARAM_VAL_LENGTH:
        mlflow.log_param(name, value)
        return

    strategy = self.long_parameters_strategy
    if strategy == "fail":
        raise ValueError(
            f"Parameter '{name}' length is {str_value_length}, "
            f"while mlflow forces it to be lower than or equal to '{MAX_PARAM_VAL_LENGTH}'. "
            "If you want to bypass it, try to change 'long_parameters_strategy' to"
            " 'tag' or 'truncate' in the 'mlflow.yml' configuration file."
        )

    if strategy == "tag":
        self._logger.warning(
            "Parameter '%s' (value length %d) is set as a tag.",
            name,
            str_value_length,
        )
        mlflow.set_tag(name, value)
        return

    if strategy == "truncate":
        self._logger.warning(
            "Parameter '%s' (value length %d) is truncated to its %s first characters.",
            name,
            str_value_length,
            MAX_PARAM_VAL_LENGTH,
        )
        mlflow.log_param(name, str_value[:MAX_PARAM_VAL_LENGTH])
        return

    # An unknown strategy used to fall through silently and the parameter
    # was never recorded anywhere; fail loudly instead of losing it.
    raise ValueError(
        "'long_parameters_strategy' must be one of [fail, truncate, tag], "
        f"got '{strategy}'"
    )


# this hook instantiation is necessary for auto-registration
mlflow_node_hook = MlflowNodeHook()


Expand Down
1 change: 1 addition & 0 deletions kedro_mlflow/template/project/mlflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ hooks:
flatten_dict_params: False # if True, parameters which are dictionaries will be split into multiple parameters when logged in mlflow, one for each key.
recursive: True # Should the dictionary flattening be applied recursively (i.e. for nested dictionaries)? Not used if `flatten_dict_params` is False.
sep: "." # In case of recursive flattening, what separator should be used between the keys? E.g. {hyperparam1: {p1:1, p2:2}} will be logged as hyperparam1.p1 and hyperparam1.p2 in mlflow.
long_parameters_strategy: fail # One of ["fail", "tag", "truncate" ] If a parameter is above mlflow limit (currently 250), what should kedro-mlflow do? -> fail, set as a tag instead of a parameter, or truncate it to its 250 first letters?


# UI-RELATED PARAMETERS -----------------
Expand Down
32 changes: 28 additions & 4 deletions tests/framework/context/test_mlflow_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ def test_get_mlflow_config(mocker, tmp_path, config_dir):
experiment=dict(name="fake_package", create=True),
run=dict(id="123456789", name="my_run", nested=True),
ui=dict(port="5151", host="localhost"),
hooks=dict(node=dict(flatten_dict_params=True, recursive=False, sep="-")),
hooks=dict(
node=dict(
flatten_dict_params=True,
recursive=False,
sep="-",
long_parameters_strategy="truncate",
)
),
),
)
expected = {
Expand All @@ -37,7 +44,12 @@ def test_get_mlflow_config(mocker, tmp_path, config_dir):
"run": {"id": "123456789", "name": "my_run", "nested": True},
"ui": {"port": "5151", "host": "localhost"},
"hooks": {
"node": {"flatten_dict_params": True, "recursive": False, "sep": "-"}
"node": {
"flatten_dict_params": True,
"recursive": False,
"sep": "-",
"long_parameters_strategy": "truncate",
}
},
}
context = load_context(tmp_path)
Expand All @@ -54,7 +66,14 @@ def test_mlflow_config_with_templated_config(mocker, tmp_path, config_dir):
experiment=dict(name="fake_package", create=True),
run=dict(id="123456789", name="my_run", nested=True),
ui=dict(port="5151", host="localhost"),
hooks=dict(node=dict(flatten_dict_params=True, recursive=False, sep="-")),
hooks=dict(
node=dict(
flatten_dict_params=True,
recursive=False,
sep="-",
long_parameters_strategy="truncate",
)
),
),
)

Expand All @@ -70,7 +89,12 @@ def test_mlflow_config_with_templated_config(mocker, tmp_path, config_dir):
"run": {"id": "123456789", "name": "my_run", "nested": True},
"ui": {"port": "5151", "host": "localhost"},
"hooks": {
"node": {"flatten_dict_params": True, "recursive": False, "sep": "-"}
"node": {
"flatten_dict_params": True,
"recursive": False,
"sep": "-",
"long_parameters_strategy": "truncate",
}
},
}

Expand Down
190 changes: 190 additions & 0 deletions tests/framework/hooks/test_node_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from kedro.io import DataCatalog, MemoryDataSet
from kedro.pipeline import Pipeline, node
from mlflow.tracking import MlflowClient
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH

from kedro_mlflow.framework.hooks import MlflowNodeHook
from kedro_mlflow.framework.hooks.node_hook import flatten_dict
Expand Down Expand Up @@ -196,3 +197,192 @@ def test_node_hook_logging(
mlflow_client = MlflowClient(mlflow_tracking_uri)
current_run = mlflow_client.get_run(run_id)
assert current_run.data.params == expected


@pytest.mark.parametrize(
    "param_length", [MAX_PARAM_VAL_LENGTH - 10, MAX_PARAM_VAL_LENGTH]
)
@pytest.mark.parametrize("strategy", ["fail", "truncate", "tag"])
def test_node_hook_logging_below_limit_all_strategy(
    tmp_path, config_dir, dummy_run_params, dummy_node, param_length, strategy
):
    """A value no longer than the mlflow limit is logged unchanged for every strategy."""
    # Select the strategy under test through the project's mlflow.yml.
    node_hook_conf = {"long_parameters_strategy": strategy}
    _write_yaml(
        tmp_path / "conf" / "base" / "mlflow.yml",
        {"hooks": {"node": node_hook_conf}},
    )

    tracking_uri = (tmp_path / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)

    hook = MlflowNodeHook()
    logged_value = "a" * param_length

    with mlflow.start_run():
        hook.before_pipeline_run(
            run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog()
        )
        hook.before_node_run(
            node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
            catalog=DataCatalog(),  # an empty catalog is enough here
            inputs={"params:my_param": logged_value},
            is_async=False,
            run_id="132",
        )
        # The active run is read while it is still open.
        active_run_id = mlflow.active_run().info.run_id

    finished_run = MlflowClient(tracking_uri).get_run(active_run_id)
    assert finished_run.data.params == {"my_param": logged_value}


@pytest.mark.parametrize(
    "param_length",
    [MAX_PARAM_VAL_LENGTH + 20],
)
def test_node_hook_logging_above_limit_truncate_strategy(
    tmp_path, config_dir, dummy_run_params, dummy_node, param_length
):
    """With the "truncate" strategy only the first MAX_PARAM_VAL_LENGTH characters are logged."""
    node_hook_conf = {"long_parameters_strategy": "truncate"}
    _write_yaml(
        tmp_path / "conf" / "base" / "mlflow.yml",
        {"hooks": {"node": node_hook_conf}},
    )

    tracking_uri = (tmp_path / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)

    hook = MlflowNodeHook()
    oversized_value = "a" * param_length

    with mlflow.start_run():
        hook.before_pipeline_run(
            run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog()
        )
        hook.before_node_run(
            node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
            catalog=DataCatalog(),  # an empty catalog is enough here
            inputs={"params:my_param": oversized_value},
            is_async=False,
            run_id="132",
        )
        # The active run is read while it is still open.
        active_run_id = mlflow.active_run().info.run_id

    finished_run = MlflowClient(tracking_uri).get_run(active_run_id)
    expected_params = {"my_param": oversized_value[:MAX_PARAM_VAL_LENGTH]}
    assert finished_run.data.params == expected_params


@pytest.mark.parametrize(
    "param_length",
    [MAX_PARAM_VAL_LENGTH + 20],
)
def test_node_hook_logging_above_limit_fail_strategy(
    tmp_path, config_dir, dummy_run_params, dummy_node, param_length
):
    """With the "fail" strategy a value above the limit makes the hook raise a ValueError."""
    node_hook_conf = {"long_parameters_strategy": "fail"}
    _write_yaml(
        tmp_path / "conf" / "base" / "mlflow.yml",
        {"hooks": {"node": node_hook_conf}},
    )

    tracking_uri = (tmp_path / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)

    hook = MlflowNodeHook()
    oversized_value = "a" * param_length
    expected_message = f"Parameter 'my_param' length is {param_length}"

    with mlflow.start_run():
        hook.before_pipeline_run(
            run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog()
        )

        # IMPORTANT: exceeding the parameter length limit should raise an
        # error on every mlflow backend, but it does not on the FileStore
        # backend:
        # https://github.com/mlflow/mlflow/issues/2814#issuecomment-628284425
        # The tests use the FileStore backend for simplicity, so mlflow itself
        # would accept the value; the failure asserted here is the one
        # enforced by the hook (slightly different from mlflow's behaviour).
        with pytest.raises(ValueError, match=expected_message):
            hook.before_node_run(
                node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
                catalog=DataCatalog(),  # an empty catalog is enough here
                inputs={"params:my_param": oversized_value},
                is_async=False,
                run_id="132",
            )


@pytest.mark.parametrize(
    "param_length",
    [MAX_PARAM_VAL_LENGTH + 20],
)
def test_node_hook_logging_above_limit_tag_strategy(
    tmp_path, config_dir, dummy_run_params, dummy_node, param_length
):
    """With the "tag" strategy a value above the limit is stored as a tag, not as a param."""
    node_hook_conf = {"long_parameters_strategy": "tag"}
    _write_yaml(
        tmp_path / "conf" / "base" / "mlflow.yml",
        {"hooks": {"node": node_hook_conf}},
    )

    tracking_uri = (tmp_path / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)

    hook = MlflowNodeHook()
    oversized_value = "a" * param_length

    with mlflow.start_run():
        hook.before_pipeline_run(
            run_params=dummy_run_params, pipeline=Pipeline([]), catalog=DataCatalog()
        )

        # NOTE: the FileStore backend used by these tests does not enforce
        # the parameter length limit, unlike other mlflow backends:
        # https://github.com/mlflow/mlflow/issues/2814#issuecomment-628284425
        # The assertions below therefore check the hook's own routing:
        # nothing is logged as a parameter and the value ends up in a tag.
        hook.before_node_run(
            node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
            catalog=DataCatalog(),  # an empty catalog is enough here
            inputs={"params:my_param": oversized_value},
            is_async=False,
            run_id="132",
        )
        # The active run is read while it is still open.
        active_run_id = mlflow.active_run().info.run_id

    finished_run = MlflowClient(tracking_uri).get_run(active_run_id)
    assert finished_run.data.params == {}

    # Ignore the tags whose key starts with "mlflow"; only compare the rest.
    user_tags = {
        key: tag
        for key, tag in finished_run.data.tags.items()
        if not key.startswith("mlflow")
    }
    assert user_tags == {"my_param": oversized_value}

0 comments on commit 4434464

Please sign in to comment.