-
Notifications
You must be signed in to change notification settings - Fork 893
/
node.py
650 lines (561 loc) · 24.3 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
# Copyright 2018-2019 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides user-friendly functions for creating nodes as parts
of Kedro pipelines.
"""
import copy
import inspect
import logging
from collections import Counter
from functools import reduce
from typing import Any, Callable, Dict, Iterable, List, Set, Union
from warnings import warn
class Node:
    """``Node`` is an auxiliary class facilitating the operations required to
    run user-provided functions as part of Kedro pipelines.
    """

    # pylint: disable=missing-type-doc
    def __init__(
        self,
        func: Callable,
        inputs: Union[None, str, List[str], Dict[str, str]],
        outputs: Union[None, str, List[str], Dict[str, str]],
        *,
        name: str = None,
        tags: Iterable[str] = None,
        decorators: Iterable[Callable] = None
    ):
        """Create a node in the pipeline by providing a function to be called
        along with variable names for inputs and/or outputs.

        Args:
            func: A function that corresponds to the node logic.
                The function should have at least one input or output.
            inputs: The name or the list of the names of variables used as
                inputs to the function. The number of names should match
                the number of arguments in the definition of the provided
                function. When Dict[str, str] is provided, variable names
                will be mapped to function argument names.
            outputs: The name or the list of the names of variables used
                as outputs to the function. The number of names should match
                the number of outputs returned by the provided function.
                When Dict[str, str] is provided, variable names will be mapped
                to the named outputs the function returns.
            name: Optional node name to be used when displaying the node in
                logs or any other visualisations.
            tags: Optional set of tags to be applied to the node.
            decorators: Optional list of decorators to be applied to the node.

        Raises:
            ValueError: Raised in the following cases:
                a) When the provided arguments do not conform to
                the format suggested by the type hint of the argument.
                b) When the node produces multiple outputs with the same name.
                c) An input has the same name as an output.
            TypeError: When ``inputs`` cannot be bound to the signature
                of ``func``.

        """
        if not callable(func):
            raise ValueError(
                _node_error_message(
                    "first argument must be a "
                    "function, not `{}`.".format(type(func).__name__)
                )
            )

        if inputs and not isinstance(inputs, (list, dict, str)):
            raise ValueError(
                _node_error_message(
                    "`inputs` type must be one of [String, List, Dict, None], "
                    "not `{}`.".format(type(inputs).__name__)
                )
            )

        if outputs and not isinstance(outputs, (list, dict, str)):
            raise ValueError(
                _node_error_message(
                    "`outputs` type must be one of [String, List, Dict, None], "
                    "not `{}`.".format(type(outputs).__name__)
                )
            )

        if not inputs and not outputs:
            raise ValueError(
                _node_error_message("it must have some `inputs` or `outputs`.")
            )

        self._validate_inputs(func, inputs)

        self._func = func
        self._inputs = inputs
        self._outputs = outputs
        self._name = name
        self._tags = set([] if tags is None else tags)
        self._decorators = list(decorators or [])

        # These checks read the attributes assigned above, so they run last.
        self._validate_unique_outputs()
        self._validate_inputs_dif_than_outputs()

    @property
    def _logger(self):
        """Logger named after this module."""
        return logging.getLogger(__name__)

    @property
    def _unique_key(self):
        """Hashable tuple of (name, inputs, outputs) used for equality,
        ordering and hashing of nodes.
        """

        def hashable(value):
            # dicts and lists are not hashable; convert them to tuples.
            if isinstance(value, dict):
                return tuple(sorted(value.items()))
            if isinstance(value, list):
                return tuple(value)
            return value

        return (self.name, hashable(self._inputs), hashable(self._outputs))

    def __eq__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self._unique_key == other._unique_key  # pylint: disable=protected-access

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self._unique_key < other._unique_key  # pylint: disable=protected-access

    def __hash__(self):
        return hash(self._unique_key)

    def __str__(self):
        def _sorted_set_to_str(xset):
            return "[{}]".format(",".join(sorted(xset)))

        out_str = _sorted_set_to_str(self.outputs) if self._outputs else "None"
        in_str = _sorted_set_to_str(self.inputs) if self._inputs else "None"

        prefix = self._name + ": " if self._name else ""
        return prefix + "{}({}) -> {}".format(self._func_name, in_str, out_str)

    def __repr__(self):  # pragma: no cover
        return "Node({}, {!r}, {!r}, {!r})".format(
            self._func_name, self._inputs, self._outputs, self._name
        )

    def __call__(self, **kwargs) -> Dict[str, Any]:
        """Run the node, using the keyword arguments as its inputs."""
        return self.run(inputs=kwargs)

    @property
    def _func_name(self):
        """Name of the wrapped function, falling back to its ``repr`` (or
        ``"<partial>"`` for ``functools.partial`` objects) when the function
        has no ``__name__`` attribute.
        """
        if hasattr(self._func, "__name__"):
            return self._func.__name__

        name = repr(self._func)
        if "functools.partial" in name:
            warn(
                "The node producing outputs `{}` is made from a `partial` function. "
                "Partial functions do not have a `__name__` attribute: consider using "
                "`functools.update_wrapper` for better log messages.".format(
                    self.outputs
                )
            )
            name = "<partial>"
        return name

    @property
    def tags(self) -> Set[str]:
        """Return the tags assigned to the node.

        Returns:
            Return the set of all assigned tags to the node.

        """
        return set(self._tags)

    def tag(self, tags: Iterable[str]) -> "Node":
        """Create a new ``Node`` which is an exact copy of the current one,
            but with more tags added to it.

        Args:
            tags: The tags to be added to the new node.

        Returns:
            A copy of the current ``Node`` object with the tags added.

        """
        return Node(
            self._func,
            self._inputs,
            self._outputs,
            name=self._name,
            tags=set(self._tags) | set(tags),
            decorators=self._decorators,
        )

    @property
    def name(self) -> str:
        """Node's name.

        Returns:
            Node's name if provided or the name of its function.
        """
        return self._name or str(self)

    @property
    def short_name(self) -> str:
        """Node's name.

        Returns:
            Returns a short user-friendly name that is not guaranteed to be unique.
        """
        if self._name:
            return self._name

        return self._func_name.replace("_", " ").title()

    @property
    def inputs(self) -> List[str]:
        """Return node inputs as a list, in the order required to bind them properly to
        the node's function. If the node's function contains ``kwargs``, then ``kwarg`` inputs
        are sorted alphabetically (for python 3.5 deterministic behavior).

        Returns:
            Node input names as a list.

        """
        if isinstance(self._inputs, dict):
            return _dict_inputs_to_list(self._func, self._inputs)
        return _to_list(self._inputs)

    @property
    def outputs(self) -> List[str]:
        """Return node outputs as a list preserving the original order
            if possible.

        Returns:
            Node output names as a list.

        """
        return _to_list(self._outputs)

    @property
    def _decorated_func(self):
        """The node function with every stored decorator applied in order."""
        return reduce(lambda g, f: f(g), self._decorators, self._func)

    def decorate(self, *decorators: Callable) -> "Node":
        """Create a new ``Node`` by applying the provided decorators to the
        underlying function. If no decorators are passed, it will return a
        new ``Node`` object, but with no changes to the function.

        Args:
            decorators: List of decorators to be applied on the node function.
                Decorators will be applied from right to left.

        Returns:
            A new ``Node`` object with the decorators applied to the function.

        Example:
        ::

            >>>
            >>> from functools import wraps
            >>>
            >>>
            >>> def apply_f(func: Callable) -> Callable:
            >>>     @wraps(func)
            >>>     def with_f(*args, **kwargs):
            >>>         args = ["f({})".format(a) for a in args]
            >>>         return func(*args, **kwargs)
            >>>     return with_f
            >>>
            >>>
            >>> def apply_g(func: Callable) -> Callable:
            >>>     @wraps(func)
            >>>     def with_g(*args, **kwargs):
            >>>         args = ["g({})".format(a) for a in args]
            >>>         return func(*args, **kwargs)
            >>>     return with_g
            >>>
            >>>
            >>> def apply_h(func: Callable) -> Callable:
            >>>     @wraps(func)
            >>>     def with_h(*args, **kwargs):
            >>>         args = ["h({})".format(a) for a in args]
            >>>         return func(*args, **kwargs)
            >>>     return with_h
            >>>
            >>>
            >>> def apply_fg(func: Callable) -> Callable:
            >>>     @wraps(func)
            >>>     def with_fg(*args, **kwargs):
            >>>         args = ["fg({})".format(a) for a in args]
            >>>         return func(*args, **kwargs)
            >>>     return with_fg
            >>>
            >>>
            >>> def identity(value):
            >>>     return value
            >>>
            >>>
            >>> # using it as a regular python decorator
            >>> @apply_f
            >>> def decorated_identity(value):
            >>>     return value
            >>>
            >>>
            >>> # wrapping the node function
            >>> old_node = node(apply_g(decorated_identity), 'input', 'output',
            >>>                 name='node')
            >>> # using the .decorate() method to apply multiple decorators
            >>> new_node = old_node.decorate(apply_h, apply_fg)
            >>> result = new_node.run(dict(input=1))
            >>>
            >>> assert old_node.name == new_node.name
            >>> assert "output" in result
            >>> assert result['output'] == "f(g(fg(h(1))))"
        """
        return Node(
            self._func,
            self._inputs,
            self._outputs,
            name=self._name,
            tags=self.tags,
            # Reversed so that the right-most decorator is applied first.
            decorators=self._decorators + list(reversed(decorators)),
        )

    def run(self, inputs: Dict[str, Any] = None) -> Dict[str, Any]:
        """Run this node using the provided inputs and return its results
        in a dictionary.

        Args:
            inputs: Dictionary of inputs as specified at the creation of
                the node.

        Raises:
            ValueError: In the following cases:
                a) The node function inputs are incompatible with the node
                input definition.
                Example 1: node definition input is a list of 2
                DataFrames, whereas only 1 was provided or 2 different ones
                were provided.
                b) The node function outputs are incompatible with the node
                output definition.
                Example 1: node function definition is a dictionary,
                whereas function returns a list.
                Example 2: node definition output is a list of 5
                strings, whereas the function returns a list of 4 objects.
            Exception: Any exception thrown during execution of the node.

        Returns:
            All produced node outputs are returned in a dictionary, where the
            keys are defined by the node outputs.

        """
        self._logger.info("Running node: %s", str(self))

        outputs = None

        if not (inputs is None or isinstance(inputs, dict)):
            raise ValueError(
                "Node.run() expects a dictionary or None, "
                "but got {} instead".format(type(inputs))
            )

        try:
            inputs = dict() if inputs is None else inputs
            if not self._inputs:
                outputs = self._run_with_no_inputs(inputs)
            elif isinstance(self._inputs, str):
                outputs = self._run_with_one_input(inputs, self._inputs)
            elif isinstance(self._inputs, list):
                outputs = self._run_with_list(inputs, self._inputs)
            elif isinstance(self._inputs, dict):
                outputs = self._run_with_dict(inputs, self._inputs)

            return self._outputs_to_dictionary(outputs)

        # purposely catch all exceptions
        except Exception as exc:
            self._logger.error("Node `%s` failed with error: \n%s", str(self), str(exc))
            # Bare ``raise`` re-raises the active exception unchanged.
            raise

    def _run_with_no_inputs(self, inputs: Dict[str, Any]):
        """Call the node function without arguments; reject any given inputs."""
        if inputs:
            raise ValueError(
                "Node {} expected no inputs, "
                "but got the following {} input(s) instead: {}".format(
                    str(self), len(inputs), list(sorted(inputs.keys()))
                )
            )

        return self._decorated_func()

    def _run_with_one_input(self, inputs: Dict[str, Any], node_input: str):
        """Call the node function with the single input named ``node_input``."""
        if len(inputs) != 1 or node_input not in inputs:
            raise ValueError(
                "Node {} expected one input named '{}', "
                "but got the following {} input(s) instead: {}".format(
                    str(self), node_input, len(inputs), list(sorted(inputs.keys()))
                )
            )

        return self._decorated_func(inputs[node_input])

    def _run_with_list(self, inputs: Dict[str, Any], node_inputs: List[str]):
        """Call the node function positionally, in ``node_inputs`` order."""
        # Compare as sets: the same dataset name may legitimately appear more
        # than once in ``node_inputs`` (feeding several arguments), while
        # ``inputs`` is a dict and holds each name only once. This mirrors the
        # de-duplication already done in ``_run_with_dict``.
        if set(node_inputs) != set(inputs.keys()):
            # This can be split in future into two cases, one successful
            raise ValueError(
                "Node {} expected {} input(s) {}, "
                "but got the following {} input(s) instead: {}.".format(
                    str(self),
                    len(node_inputs),
                    node_inputs,
                    len(inputs),
                    list(sorted(inputs.keys())),
                )
            )
        # Ensure the function gets the inputs in the correct order
        return self._decorated_func(*[inputs[item] for item in node_inputs])

    def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]):
        """Call the node function with keyword arguments; ``node_inputs`` maps
        argument names to the keys looked up in ``inputs``.
        """
        all_available = set(node_inputs.values()).issubset(inputs.keys())
        if len(set(node_inputs.values())) != len(inputs) or not all_available:
            # This can be split in future into two cases, one successful
            raise ValueError(
                "Node {} expected {} input(s) {}, "
                "but got the following {} input(s) instead: {}.".format(
                    str(self),
                    len(set(node_inputs.values())),
                    list(sorted(set(node_inputs.values()))),
                    len(inputs),
                    list(sorted(inputs.keys())),
                )
            )
        kwargs = {arg: inputs[alias] for arg, alias in node_inputs.items()}
        return self._decorated_func(**kwargs)

    def _outputs_to_dictionary(self, outputs):
        """Map the value returned by the node function onto the node's
        declared output names.

        Raises:
            ValueError: When the returned value does not match the declared
                outputs (not a dict, mismatching keys, not a list/tuple, or
                wrong number of items).
        """

        def _from_dict():
            if set(self._outputs.keys()) != set(outputs.keys()):
                raise ValueError(
                    "Failed to save outputs of node {}.\n"
                    "The node's output keys {} do not "
                    "match with the returned output's keys {}.".format(
                        # node definition keys first, returned keys second,
                        # matching the order of the message text
                        str(self),
                        set(self._outputs.keys()),
                        set(outputs.keys()),
                    )
                )
            return {name: outputs[key] for key, name in self._outputs.items()}

        def _from_list():
            if not isinstance(outputs, (list, tuple)):
                raise ValueError(
                    "Failed to save outputs of node {}.\n"
                    "The node definition contains a list of "
                    "outputs {}, whereas the node function "
                    "returned a `{}`.".format(
                        str(self), self._outputs, type(outputs).__name__
                    )
                )
            if len(outputs) != len(self._outputs):
                raise ValueError(
                    "Failed to save outputs of node {}.\n"
                    "The node function returned {} output(s), "
                    "whereas the node definition contains {} "
                    "output(s).".format(str(self), len(outputs), len(self._outputs))
                )

            return dict(zip(self._outputs, outputs))

        if isinstance(self._outputs, dict) and not isinstance(outputs, dict):
            raise ValueError(
                "Failed to save outputs of node {}.\n"
                "The node output is a dictionary, whereas the "
                "function output is not.".format(str(self))
            )

        if self._outputs is None:
            return {}
        if isinstance(self._outputs, str):
            return {self._outputs: outputs}
        if isinstance(self._outputs, dict):
            return _from_dict()
        return _from_list()

    def _validate_inputs(self, func, inputs):
        """Check that ``inputs`` can be bound to the signature of ``func``.

        Raises:
            TypeError: When binding ``inputs`` to the signature fails.
        """
        # inspect does not support built-in Python functions written in C.
        # Thus we only validate func if it is not built-in.
        if not inspect.isbuiltin(func):
            args, kwargs = self._process_inputs_for_bind(inputs)
            # Computed once and reused for both the bind and the error message.
            signature = inspect.signature(func, follow_wrapped=False)
            try:
                signature.bind(*args, **kwargs)
            except Exception as exc:
                func_args = signature.parameters.keys()
                raise TypeError(
                    "Inputs of function expected {}, but got {}".format(
                        str(list(func_args)), str(inputs)
                    )
                ) from exc

    def _validate_unique_outputs(self):
        """Raise ``ValueError`` if any output name is declared more than once."""
        # Subtracting one occurrence of each name leaves only the duplicates.
        diff = Counter(self.outputs) - Counter(set(self.outputs))
        if diff:
            raise ValueError(
                "Failed to create node {} due to duplicate"
                " output(s) {}.\nNode outputs must be unique.".format(
                    str(self), set(diff.keys())
                )
            )

    def _validate_inputs_dif_than_outputs(self):
        """Raise ``ValueError`` if a name is both an input and an output."""
        common_in_out = set(self.inputs).intersection(set(self.outputs))
        if common_in_out:
            raise ValueError(
                "Failed to create node {}.\n"
                "A node cannot have the same inputs and outputs: "
                "{}".format(str(self), common_in_out)
            )

    @staticmethod
    def _process_inputs_for_bind(inputs: Union[None, str, List[str], Dict[str, str]]):
        """Split ``inputs`` into ``(args, kwargs)`` suitable for
        ``inspect.Signature.bind``: a string or list becomes positional
        arguments, a dict becomes keyword arguments.
        """
        # Safeguard that we do not mutate list inputs
        inputs = copy.copy(inputs)
        args = []  # type: List[str]
        kwargs = {}  # type: Dict[str, str]
        if isinstance(inputs, str):
            args = [inputs]
        elif isinstance(inputs, list):
            args = inputs
        elif isinstance(inputs, dict):
            kwargs = inputs
        return args, kwargs
def _node_error_message(msg) -> str:
return (
"Invalid Node definition: {}\n"
"Format should be: node(function, inputs, outputs)"
).format(msg)
def node(  # pylint: disable=missing-type-doc
    func: Callable,
    inputs: Union[None, str, List[str], Dict[str, str]],
    outputs: Union[None, str, List[str], Dict[str, str]],
    *,
    name: str = None,
    tags: Iterable[str] = None
) -> Node:
    """Build a pipeline ``Node`` from a function and the names of the
    variables it consumes and/or produces.

    Args:
        func: The callable implementing the node logic. It should have at
            least one input or output.
        inputs: A single name, a list of names, or a dict describing the
            variables fed to ``func``. The number of names should match the
            number of arguments ``func`` declares. With a Dict[str, str],
            variable names are mapped to function argument names.
        outputs: A single name, a list of names, or a dict describing the
            variables produced by ``func``. The number of names should match
            the number of values ``func`` returns. With a Dict[str, str],
            variable names are mapped to the named outputs ``func`` returns.
        name: Optional node name used when displaying the node in logs or
            any other visualisations.
        tags: Optional set of tags to be applied to the node.

    Returns:
        A Node object with mapped inputs, outputs and function.

    Example:
    ::

        >>> import pandas as pd
        >>> import numpy as np
        >>>
        >>> def clean_data(cars: pd.DataFrame,
        >>>                boats: pd.DataFrame) -> Dict[str, pd.DataFrame]:
        >>>     return dict(cars_df=cars.dropna(), boats_df=boats.dropna())
        >>>
        >>> def halve_dataframe(data: pd.DataFrame) -> List[pd.DataFrame]:
        >>>     return np.array_split(data, 2)
        >>>
        >>> nodes = [
        >>>     node(clean_data,
        >>>          inputs=['cars2017', 'boats2017'],
        >>>          outputs=dict(cars_df='clean_cars2017',
        >>>                       boats_df='clean_boats2017')),
        >>>     node(halve_dataframe,
        >>>          'clean_cars2017',
        >>>          ['train_cars2017', 'test_cars2017']),
        >>>     node(halve_dataframe,
        >>>          dict(data='clean_boats2017'),
        >>>          ['train_boats2017', 'test_boats2017'])
        >>> ]
    """
    optional_settings = dict(name=name, tags=tags)
    return Node(func, inputs, outputs, **optional_settings)
def _dict_inputs_to_list(func: Callable[[Any], Any], inputs: Dict[str, str]):
"""Convert a dict representation of the node inputs to a list , ensuring
the appropriate order for binding them to the node's function.
"""
sig = inspect.signature(func).bind(**inputs)
# for deterministic behavior in python 3.5, sort kwargs inputs alphabetically
return list(sig.args) + sorted(sig.kwargs.values())
def _to_list(element: Union[None, str, List[str], Dict[str, str]]) -> List:
"""Make a list out of node inputs/outputs.
Returns:
List[str]: Node input/output names as a list to standardise.
"""
if element is None:
return list()
if isinstance(element, str):
return [element]
if isinstance(element, dict):
return sorted(element.values())
return element