-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecorators.py
105 lines (78 loc) · 2.79 KB
/
decorators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import functools
import inspect
from typing import Any, List, Mapping

from flex.model import FlexModel
def __inspect_arguments(func):
    """Check that ``func`` declares at least one parameter.

    Called by the decorators in this module so that a function with an
    empty signature is rejected at decoration time.

    Args:
        func: The callable being decorated.

    Raises:
        AssertionError: If ``func`` declares no parameters. It is raised
            explicitly rather than through an ``assert`` statement, so the
            check is not stripped when Python runs with ``-O``.
    """
    if not inspect.signature(func).parameters:
        raise AssertionError(
            "The decorated function is expected to have at least one argument."
        )
def init_server_model(func):
    """Turn a model-building function into a server-initialisation step.

    The returned wrapper has the signature
    ``(server_flex_model, _, *args, **kwargs)``. It ignores the second
    positional argument, calls ``func(*args, **kwargs)`` (without the server
    model), and merges the result into ``server_flex_model`` with
    ``update``. The wrapper itself returns ``None``.

    Raises:
        AssertionError: At decoration time, if ``func`` has no parameters.
    """
    __inspect_arguments(func=func)

    @functools.wraps(func)
    def _init_server_model_(server_flex_model: FlexModel, _, *args, **kwargs):
        initial_contents = func(*args, **kwargs)
        server_flex_model.update(initial_contents)

    return _init_server_model_
def deploy_server_model(func):
    """Turn a deployment function into a server-to-clients broadcast step.

    The returned wrapper has the signature
    ``(server_flex_model, clients_flex_models, *args, **kwargs)``. For every
    key ``k`` of ``clients_flex_models`` it calls
    ``func(server_flex_model, *args, **kwargs)`` and merges the result into
    ``clients_flex_models[k]`` with ``update``. The wrapper returns ``None``.

    ``func`` is invoked once per client rather than once overall, so each
    client receives the result of its own call. Presumably this is so that
    clients do not share mutable objects; confirm before hoisting the call.

    Raises:
        AssertionError: At decoration time, if ``func`` has no parameters.
    """
    __inspect_arguments(func=func)

    @functools.wraps(func)
    def _deploy_model_(
        server_flex_model: FlexModel,
        # Iterated for keys and indexed with them, so this must be a mapping
        # (not a list, as the old ``List[FlexModel]`` hint claimed).
        clients_flex_models: Mapping[Any, FlexModel],
        *args,
        **kwargs,
    ):
        for k in clients_flex_models:
            # Reminder: do not assign a new object to clients_flex_models[k];
            # update the existing client FlexModel in place instead.
            clients_flex_models[k].update(func(server_flex_model, *args, **kwargs))

    return _deploy_model_
def collect_clients_weights(func):
    """Turn a weight-extraction function into a clients-to-aggregator step.

    The returned wrapper has the signature
    ``(aggregator_flex_model, clients_flex_models, *args, **kwargs)``. It
    first ensures ``aggregator_flex_model["weights"]`` exists, creating an
    empty list if the key is absent. Then, for every key ``k`` of
    ``clients_flex_models``, it appends
    ``func(clients_flex_models[k], *args, **kwargs)`` to that list. The
    wrapper returns ``None``.

    Raises:
        AssertionError: At decoration time, if ``func`` has no parameters.
    """
    __inspect_arguments(func=func)

    @functools.wraps(func)
    def _collect_weights_(
        aggregator_flex_model: FlexModel,
        # Iterated for keys and indexed with them, so this must be a mapping
        # (not a list, as the old ``List[FlexModel]`` hint claimed).
        clients_flex_models: Mapping[Any, FlexModel],
        *args,
        **kwargs,
    ):
        if "weights" not in aggregator_flex_model:
            aggregator_flex_model["weights"] = []
        for k in clients_flex_models:
            client_weights = func(clients_flex_models[k], *args, **kwargs)
            aggregator_flex_model["weights"].append(client_weights)

    return _collect_weights_
def aggregate_weights(func):
    """Turn an aggregation function into an aggregator-side reduction step.

    The returned wrapper has the signature
    ``(aggregator_flex_model, _, *args, **kwargs)`` and ignores the second
    positional argument. It does three things:

    - passes ``aggregator_flex_model["weights"]`` (plus any extra
      arguments) to ``func``;
    - stores the result under ``aggregator_flex_model["aggregated_weights"]``;
    - resets ``aggregator_flex_model["weights"]`` to an empty list.

    The wrapper returns ``None``. A missing ``"weights"`` key surfaces as
    whatever the model's ``__getitem__`` raises.

    Raises:
        AssertionError: At decoration time, if ``func`` has no parameters.
    """
    __inspect_arguments(func=func)

    @functools.wraps(func)
    def _aggregate_weights_(aggregator_flex_model: FlexModel, _, *args, **kwargs):
        gathered = aggregator_flex_model["weights"]
        aggregator_flex_model["aggregated_weights"] = func(gathered, *args, **kwargs)
        # Empty the buffer so the next collection round starts fresh.
        aggregator_flex_model["weights"] = []

    return _aggregate_weights_
def set_aggregated_weights(func):
    """Turn a weight-setting function into an aggregator-to-servers step.

    The returned wrapper has the signature
    ``(aggregator_flex_model, servers_flex_models, *args, **kwargs)``. For
    every key ``k`` of ``servers_flex_models`` it calls::

        func(servers_flex_models[k],
             aggregator_flex_model["aggregated_weights"],
             *args, **kwargs)

    The return value of ``func`` is discarded, so any effect must come from
    ``func`` mutating its arguments. The wrapper returns ``None``.

    Raises:
        AssertionError: At decoration time, if ``func`` has no parameters.
    """
    __inspect_arguments(func=func)

    @functools.wraps(func)
    def _deploy_aggregated_weights_(
        aggregator_flex_model: FlexModel,
        # A collection iterated for keys and indexed with them, not a single
        # FlexModel as the previous annotation stated.
        servers_flex_models: Mapping[Any, FlexModel],
        *args,
        **kwargs,
    ):
        for k in servers_flex_models:
            func(
                servers_flex_models[k],
                aggregator_flex_model["aggregated_weights"],
                *args,
                **kwargs,
            )

    return _deploy_aggregated_weights_
def evaluate_server_model(func):
    """Turn an evaluation function into a server-side evaluation step.

    The returned wrapper has the signature
    ``(server_flex_model, _, *args, **kwargs)`` and ignores the second
    positional argument. It returns ``func(server_flex_model, *args, **kwargs)``
    unchanged.

    Raises:
        AssertionError: At decoration time, if ``func`` has no parameters.
    """
    __inspect_arguments(func=func)

    @functools.wraps(func)
    def _evaluate_server_model_(server_flex_model: FlexModel, _, *args, **kwargs):
        outcome = func(server_flex_model, *args, **kwargs)
        return outcome

    return _evaluate_server_model_