-
Notifications
You must be signed in to change notification settings - Fork 34
/
node_hook.py
143 lines (124 loc) · 5.5 KB
/
node_hook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import logging
from typing import Any, Dict, Union
import mlflow
from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH
from kedro_mlflow.framework.context import get_mlflow_config
from kedro_mlflow.framework.hooks.utils import _assert_mlflow_enabled
class MlflowNodeHook:
    """Kedro hook that logs the parameters consumed by each node to mlflow."""

    def __init__(self):
        # Defaults in effect until ``before_pipeline_run`` overrides them with
        # the values found in the mlflow configuration.
        self._is_mlflow_enabled = True
        self.long_parameters_strategy = "fail"
        self.sep = "."
        self.recursive = True
        self.flatten = False

    @property
    def _logger(self) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    @hook_impl
    def before_pipeline_run(
        self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
    ) -> None:
        """Read the node hook options before the pipeline starts.

        The ``"pipeline_name"`` entry of ``run_params`` decides whether mlflow
        is enabled for this run; when it is, the flattening and
        long-parameter options are loaded from the mlflow configuration.

        Args:
            run_params: The params needed for the given run.
                Should be identical to the data logged by Journal.
                Expected schema: {
                    "run_id": str,
                    "project_path": str,
                    "env": str,
                    "kedro_version": str,
                    "tags": Optional[List[str]],
                    "from_nodes": Optional[List[str]],
                    "to_nodes": Optional[List[str]],
                    "node_names": Optional[List[str]],
                    "from_inputs": Optional[List[str]],
                    "load_versions": Optional[List[str]],
                    "pipeline_name": str,
                    "extra_params": Optional[Dict[str, Any]],
                }
            pipeline: The ``Pipeline`` that will be run (unused here).
            catalog: The ``DataCatalog`` to be used during the run (unused here).
        """
        self._is_mlflow_enabled = _assert_mlflow_enabled(run_params["pipeline_name"])
        if not self._is_mlflow_enabled:
            return

        node_hook_opts = get_mlflow_config().node_hook_opts
        self.flatten = node_hook_opts["flatten_dict_params"]
        self.recursive = node_hook_opts["recursive"]
        self.sep = node_hook_opts["sep"]
        self.long_parameters_strategy = node_hook_opts["long_parameters_strategy"]

    @hook_impl
    def before_node_run(
        self,
        node: Node,
        catalog: DataCatalog,
        inputs: Dict[str, Any],
        is_async: bool,
        run_id: str,
    ) -> None:
        """Log every parameter input of the node about to run.

        Only parameters are logged here; artifacts have to be declared
        manually in the catalog.

        Args:
            node: The ``Node`` to run.
            catalog: A ``DataCatalog`` containing the node's inputs and outputs.
            inputs: Mapping of input dataset names to their loaded values.
            is_async: Whether the node was run in ``async`` mode.
            run_id: The id of the run.
        """
        if not self._is_mlflow_enabled:
            return

        # Parameters are recognised through kedro's reserved dataset names:
        # "params:<name>" (logged as "<name>") and the global "parameters".
        prefix = "params:"
        collected = {}
        for dataset_name, dataset_value in inputs.items():
            if dataset_name.startswith(prefix):
                target_key = dataset_name[len(prefix):]
            elif dataset_name == "parameters":
                target_key = dataset_name
            else:
                continue
            collected[target_key] = dataset_value

        # Dictionary parameters may be flattened for readability.
        if self.flatten:
            collected = flatten_dict(
                d=collected, recursive=self.recursive, sep=self.sep
            )

        # Each parameter is logged according to the configured strategy.
        for param_name, param_value in collected.items():
            self.log_param(param_name, param_value)

    def log_param(self, name: str, value: Union[Dict, int, bool, str]) -> None:
        """Log one parameter, applying the long-parameter strategy if needed.

        Values whose string form fits within ``MAX_PARAM_VAL_LENGTH`` are
        logged as-is. Longer values are handled per
        ``self.long_parameters_strategy``: ``"fail"`` raises, ``"tag"`` stores
        the value as a tag, ``"truncate"`` logs only the leading characters.

        Args:
            name: Name under which the parameter is logged.
            value: Parameter value.

        Raises:
            ValueError: If the value is too long and the strategy is ``"fail"``.
        """
        text = str(value)
        n_chars = len(text)
        if n_chars <= MAX_PARAM_VAL_LENGTH:
            return mlflow.log_param(name, value)

        strategy = self.long_parameters_strategy
        if strategy == "fail":
            raise ValueError(
                f"Parameter '{name}' length is {n_chars}, "
                f"while mlflow forces it to be lower than '{MAX_PARAM_VAL_LENGTH}'. "
                "If you want to bypass it, try to change 'long_parameters_strategy' to"
                " 'tag' or 'truncate' in the 'mlflow.yml'configuration file."
            )
        elif strategy == "tag":
            self._logger.warning(
                f"Parameter '{name}' (value length {n_chars}) is set as a tag."
            )
            mlflow.set_tag(name, value)
        elif strategy == "truncate":
            self._logger.warning(
                f"Parameter '{name}' (value length {n_chars}) is truncated to its {MAX_PARAM_VAL_LENGTH} first characters."
            )
            mlflow.log_param(name, text[:MAX_PARAM_VAL_LENGTH])
        # Any other strategy value falls through without logging anything.
# This module-level instantiation of the hook is necessary for auto-registration.
mlflow_node_hook = MlflowNodeHook()
def flatten_dict(d, recursive: bool = True, sep="."):
    """Flatten nested dictionaries into one level with joined keys.

    Each dict-valued entry is replaced by one entry per child, keyed by
    ``"<parent><sep><child>"``. Non-dict values are kept unchanged under
    their original key. A key whose value is an empty dict produces no
    entry in the result.

    Args:
        d: The dictionary to flatten. It is not modified.
        recursive: If True, nested dictionaries are flattened at every depth.
            If False, only the first level of nesting is expanded and deeper
            dictionaries are kept as values.
        sep: Separator inserted between parent and child keys.

    Returns:
        A new, flattened dictionary.
    """

    def expand(key, value):
        if not isinstance(value, dict):
            return [(key, value)]
        # Forward ``recursive`` and ``sep`` so that nested levels are joined
        # with the caller's separator instead of the default ".".
        children = (
            flatten_dict(value, recursive=recursive, sep=sep) if recursive else value
        )
        # f-string joining also accepts non-string keys (e.g. integers).
        return [(f"{key}{sep}{k}", v) for k, v in children.items()]

    return dict(item for k, v in d.items() for item in expand(k, v))