-
Notifications
You must be signed in to change notification settings - Fork 0
/
pool.py
323 lines (277 loc) · 13.6 KB
/
pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""
Copyright (C) 2024 Instituto Andaluz Interuniversitario en Ciencia de Datos e Inteligencia Computacional (DaSCI).
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import functools
import random
from typing import Callable, Hashable, Union
from flex.actors.actors import FlexActors, FlexRoleManager
from flex.actors.architectures import client_server_architecture, p2p_architecture
from flex.data import FedDataset
from flex.model.model import FlexModel
class FlexPool:
    """Orchestrate the training phase of a federated learning experiment.

    A FlexPool groups a set of actors together with their data and models,
    and checks that communication between pools respects the actors' roles
    while a federated model is being trained.

    Note: At the moment this class only supports Horizontal Federated Learning,
    but in the future it will cover Vertical Federated Learning and Transfer
    Learning, so users can simulate all the experiments correctly.

    Attributes
    ----------
    - flex_data (FedDataset): The federated dataset prepared to be used.
    - flex_actors (FlexActors): Actors with their roles.
    - flex_models (dict): A dictionary mapping each actor id to its FlexModel.
      When it is not provided, an empty FlexModel is created for every actor;
      the actual model is expected to be initialized later through ``map``.

    Two class methods are offered to create two architectures: client-server
    and p2p. In the client-server architecture, every id from the FedDataset
    is assigned to be a client, and a third-party actor, supposed to be
    neutral, is created to orchestrate the training. In the p2p architecture,
    each id from the FedDataset is assigned to be client, server and
    aggregator. In both cases the method creates the actors, so the user only
    has to apply the map function to train the model.

    If the user wants to use a different architecture, she will need to create
    the actors by using the FlexActors class. For example, a client-server
    architecture with multiple aggregators to carry out the aggregation.
    """

    def __init__(
        self,
        flex_data: FedDataset,
        flex_actors: FlexActors,
        flex_models: dict[Hashable, FlexModel] = None,
    ) -> None:
        """Build a pool and validate it.

        Args:
        -----
            flex_data (FedDataset): Data owned by the actors, keyed by actor id.
            flex_actors (FlexActors): Roles of the actors, keyed by actor id.
            flex_models (dict, optional): Models keyed by actor id. If None,
                a fresh FlexModel is created for every actor.

        Raises:
        -------
            ValueError: If the resulting pool is not valid (see ``validate``).
        """
        self._actors = flex_actors
        self._data = flex_data
        self._models = flex_models
        if self._models is None:
            self._models = {k: FlexModel() for k in self._actors}
            # Freshly created models do not know their owner yet.
            for k in self._models:
                self._models[k].actor_id = k
        self.validate()  # check if the provided arguments generate a valid object

    @classmethod
    def check_compatibility(cls, src_pool, dst_pool):
        """Check the compatibility between two different pools.

        This method is used by the map function to check whether a function
        can be applied from the source pool to the destination pool.

        Args:
        -----
            src_pool (FlexPool): Source pool. Pool that will send the message.
            dst_pool (FlexPool): Destination pool. Pool that will receive the message.

        Returns:
        --------
            bool: True if every actor of the source pool is compatible with
            every actor of the destination pool. False otherwise.
        """
        return all(
            FlexRoleManager.check_compatibility(src, dst)
            for _, src in src_pool._actors.items()
            for _, dst in dst_pool._actors.items()
        )

    def map(self, func: Callable, dst_pool: FlexPool = None, **kwargs):
        """Send a message from this pool (the source) to another pool.

        Sending a message means applying ``func``. If no destination pool is
        provided, the function is applied to the source pool itself. Messages
        are sent in order to complete a round in the Federated Learning (FL)
        paradigm; some examples are:

        - send_model: Aggregators send the model to the server when the aggregation is done.
        - aggregation_step: Clients send the model to the aggregator so it can apply the
          aggregation mechanism given.
        - deploy_model: Server sends the global model to the clients once the weights have
          been aggregated.
        - init_model: Server sends the model to train during the learning phase, so the
          clients can initialize it. This is a particular case of deploy_model.

        Args:
        -----
            func (Callable): If dst_pool is None, func is called once per actor
                of this pool as ``func(actor_model, actor_data, **kwargs)``.
                Otherwise it is called once per actor of this pool as
                ``func(actor_model, dst_pool_models, **kwargs)``, where
                ``dst_pool_models`` is the whole dictionary of models of the
                destination pool.
            dst_pool (FlexPool): Pool that will receive the message from the
                source pool (self); it can be None.
            **kwargs: Extra keyword arguments forwarded to every call of func.

        Raises:
        -------
            ValueError: If the pools are not allowed to communicate.

        Returns:
        --------
            List[Any] | None: The results of every call to func, one per actor
            in the source pool. If any call returned None, nothing is returned
            (None), which is the case for functions that only have side effects.
        """
        if dst_pool is None:
            res = [
                func(self._models.get(i), self._data.get(i), **kwargs)
                for i in self._actors
            ]
        elif FlexPool.check_compatibility(self, dst_pool):
            res = [
                func(self._models.get(i), dst_pool._models, **kwargs)
                for i in self._actors
            ]
        else:
            raise ValueError(
                "Source and destination pools are not allowed to comunicate, ensure that their actors can communicate."
            )
        # Only hand results back when every call produced one.
        if all(ele is not None for ele in res):
            return res

    def select(self, criteria: Union[int, Callable], *criteria_args, **criteria_kwargs):
        """Return a subset of this pool meeting a certain criteria.

        If criteria is an integer, a subset of the available nodes of that size
        is randomly sampled. If criteria is a function, the nodes for which the
        function returns a truthy value are selected. The function receives at
        least two arguments: a node id and the roles associated to that node id.

        Note: The returned pool holds references, not copies. Its actors,
        models and data entries are the very same objects stored in this pool,
        so changes made to them through the returned pool affect this pool.

        Args:
        -----
            criteria (int, Callable): if a function is provided, it is called as
                ``criteria(node_id, node_roles, *criteria_args, **criteria_kwargs)``
                for each node. Otherwise, criteria is interpreted as the number
                of nodes to randomly sample from the pool.
            criteria_args: additional args for the criteria function. Otherwise ignored.
            criteria_kwargs: additional keyword args for the criteria function.
                Otherwise ignored.

        Returns:
        --------
            FlexPool: a pool that contains the nodes that meet the criteria.
        """
        new_actors = FlexActors()
        new_data = FedDataset()
        new_models = {}
        available_nodes = list(self._actors.keys())
        if callable(criteria):
            selected_keys = [
                actor_id
                for actor_id in available_nodes
                if criteria(
                    actor_id, self._actors[actor_id], *criteria_args, **criteria_kwargs
                )
            ]
        else:
            num_nodes = criteria
            selected_keys = random.sample(available_nodes, num_nodes)

        for actor_id in selected_keys:
            new_actors[actor_id] = self._actors[actor_id]
            new_models[actor_id] = self._models[actor_id]
            # Not every actor owns data, so data is copied over only if present.
            if actor_id in self._data:
                new_data[actor_id] = self._data[actor_id]

        return FlexPool(
            flex_actors=new_actors, flex_data=new_data, flex_models=new_models
        )

    def __len__(self):
        """Return the number of actors in the pool."""
        return len(self._actors)

    @functools.cached_property
    def actor_ids(self):
        """List of the ids of the actors in the pool.

        The value is computed on first access and cached on the instance.
        """
        return list(self._actors.keys())

    @functools.cached_property
    def clients(self):
        """Property to get all the clients available in a pool.

        The value is computed on first access and cached on the instance.

        Returns:
        --------
            FlexPool: Pool containing all the clients from a pool
        """
        return self.select(lambda a, b: FlexRoleManager.is_client(b))

    @functools.cached_property
    def aggregators(self):
        """Property to get all the aggregators available in a pool.

        The value is computed on first access and cached on the instance.

        Returns:
        --------
            FlexPool: Pool containing all the aggregators from a pool
        """
        return self.select(lambda a, b: FlexRoleManager.is_aggregator(b))

    @functools.cached_property
    def servers(self):
        """Property to get all the servers available in a pool.

        The value is computed on first access and cached on the instance.

        Returns:
        --------
            FlexPool: Pool containing all the servers from a pool
        """
        return self.select(lambda a, b: FlexRoleManager.is_server(b))

    def validate(self):  # sourcery skip: de-morgan
        """Check whether the pool is consistent.

        Raises:
        -------
            ValueError: If an actor with the client role has no data, if an
                actor has no model, or if there is data, a model, or a model's
                ``actor_id`` that does not correspond to any actor.
        """
        actors_ids = self._actors.keys()
        data_ids = self._data.keys()
        models_ids = self._models.keys()
        for actor_id in actors_ids:
            if (
                FlexRoleManager.is_client(self._actors[actor_id])
                and actor_id not in data_ids
            ):
                # f-string so the offending id is actually reported.
                raise ValueError(
                    "All node with client role must have data. "
                    f"Node with client role and id {actor_id} does not have any data."
                )
            if actor_id not in self._models:
                raise ValueError(
                    f"All nodes must have a FlexModel object associated as a model, but {actor_id} does not."
                )
        # Owner ids recorded inside the models themselves.
        flex_models_ids = {self._models[k].actor_id for k in self._models}
        if not (
            actors_ids >= data_ids
            and actors_ids >= models_ids
            and actors_ids >= flex_models_ids
        ):  # noqa: E501
            raise ValueError("Each node with data or model must have a role asigned")

    @classmethod
    def client_server_pool(
        cls,
        fed_dataset: FedDataset,
        init_func: Callable,
        server_id: str = "server",
        **kwargs,
    ):
        """Create a client-server architecture for a given FedDataset.

        This function is used when you have a FedDataset and you want to start
        the learning phase following a traditional client-server architecture.
        Actors are built by ``client_server_architecture`` from the ids of the
        dataset and ``server_id``; the intent is that every id of the dataset
        becomes a client and a new actor acts as server-aggregator,
        orchestrating the learning phase.

        Args:
        -----
            fed_dataset (FedDataset): Federated dataset used to train a model.
            init_func (Callable): Function applied, through ``map``, to every
                server of the new pool in order to initialize its model. It is
                called with the server model, the server's data entry (if any)
                and ``**kwargs``.
            server_id (str): Id given to the server actor. Defaults to "server".
            **kwargs: Extra keyword arguments forwarded to init_func.

        Returns:
        --------
            FlexPool: A FlexPool with the assigned roles for a client-server architecture.
        """
        actors = client_server_architecture(fed_dataset.keys(), server_id)
        new_arch = cls(
            flex_data=fed_dataset,
            flex_actors=actors,
            flex_models=None,
        )
        new_arch.servers.map(init_func, **kwargs)
        return new_arch

    @classmethod
    def p2p_pool(cls, fed_dataset: FedDataset, init_func: Callable, **kwargs):
        """Create a peer-to-peer (p2p) architecture for a given FedDataset.

        This method is used when you have a FedDataset and you want to start
        the learning phase following a p2p architecture. Actors are built by
        ``p2p_architecture``; the intent is that every id of the dataset gets
        all roles (client-aggregator-server), so each participant in the
        learning phase can act as client, aggregator and server.

        Args:
        -----
            fed_dataset (FedDataset): Federated dataset used to train a model.
            init_func (Callable): Function applied, through ``map``, to every
                server of the new pool in order to initialize its model. It is
                called with the actor model, the actor data and ``**kwargs``.
            **kwargs: Extra keyword arguments forwarded to init_func.

        Returns:
        --------
            FlexPool: A FlexPool with the assigned roles for a p2p architecture.
        """
        new_arch = cls(
            flex_data=fed_dataset,
            flex_actors=p2p_architecture(fed_dataset),
            flex_models=None,
        )
        new_arch.servers.map(init_func, **kwargs)
        return new_arch

    def __iter__(self):
        """Iterate over the ids of the actors in the pool."""
        yield from self._actors