Skip to content

Commit

Permalink
Merge pull request #98 from FLEXible-FL/FLEX-114-cambiar-metodo-filte…
Browse files Browse the repository at this point in the history
…r-a-metodo-select
  • Loading branch information
AlArgente authored Jun 28, 2023
2 parents e8c59fe + fc400e4 commit fb026d1
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 214 deletions.
68 changes: 38 additions & 30 deletions flex/pool/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
import random
from typing import Callable, Hashable
from typing import Callable, Hashable, Union

from flex.actors.actors import FlexActors, FlexRole, FlexRoleManager
from flex.data import FedDataset
Expand Down Expand Up @@ -126,40 +126,48 @@ def map(self, func: Callable, dst_pool: FlexPool = None, **kwargs):
if all(ele is not None for ele in res):
return res

def filter(self, func: Callable = None, node_dropout: float = 0.0, **kwargs):
"""Function that filter the PoolManager by actors given a function.
The function must return True/False, and it receives the args and kwargs arguments
for its internal use. Also, the function receives an actor_id and an actor_role.
The actor_id is a string, and the actor_role is a FlexRole.
def select(self, criteria: Union[int, Callable], *criteria_args, **criteria_kwargs):
"""Function that returns a subset of a FlexPool meeting a certain criteria.
If criteria is an integer, a subset of the available nodes of size criteria is
randomly sampled. If criteria is a function, then we select those nodes where
the function returns True. Note that the function must have at least
two arguments: a client id and the roles associated with that client id.
The actor_id is a string, and the actor_role is a FlexRole object.
Note: This function doesn't send a copy of the original pool, it sends a reference.
Changes made on the new pool may affect the original pool.
Changes made on the returned pool affect the original pool.
Args:
func (Callable): Function to filter the pool by. The function must return True/False.
node_dropout (float): Fraction of nodes to drop from the training phase. This param
must be a value in the range [0, 1]. Values greater than 1 are clamped to 1, so every
node is dropped and an empty pool is returned. Negative values are clamped to 0, so no node is dropped.
criteria (int, Callable): if a function is provided, then it must return
True/False values for each pair of client_id, client_roles. Additional arguments
required for the function are passed in criteria_args and criteria_kwargs.
Otherwise, criteria is interpreted as number of nodes to randomly sample from the pool.
criteria_args: additional args required for the criteria function. Otherwise ignored.
criteria_kwargs: additional keyword args required for the criteria function. Otherwise ignored.
Returns:
FlexPool: New filtered pool.
FlexPool: a pool that contains the nodes that meet the criteria.
"""
node_dropout = max(0, node_dropout)
node_dropout = max(1 - min(node_dropout, 1), 0)
num_nodes = round(len(self._actors) * node_dropout)
selected_nodes = random.sample(list(self._actors.keys()), num_nodes)
new_actors = FlexActors()
new_data = FedDataset()
new_models = {}
for actor_id in selected_nodes:
cond = (
True
if func is None
else func(actor_id, self._actors[actor_id], **kwargs)
)
if cond:
new_actors[actor_id] = self._actors[actor_id]
new_models[actor_id] = self._models[actor_id]
if actor_id in self._data:
new_data[actor_id] = self._data[actor_id]
available_nodes = self._actors.keys()
if callable(criteria):
selected_keys = [
actor_id
for actor_id in available_nodes
if criteria(
actor_id, self._actors[actor_id], *criteria_args, **criteria_kwargs
)
]
else:
num_nodes = criteria
selected_keys = random.sample(available_nodes, num_nodes)

for actor_id in selected_keys:
new_actors[actor_id] = self._actors[actor_id]
new_models[actor_id] = self._models[actor_id]
if actor_id in self._data:
new_data[actor_id] = self._data[actor_id]

return FlexPool(
flex_actors=new_actors, flex_data=new_data, flex_models=new_models
Expand All @@ -179,7 +187,7 @@ def clients(self):
Returns:
FlexPool: Pool containing all the clients from a pool
"""
return self.filter(lambda a, b: FlexRoleManager.is_client(b))
return self.select(lambda a, b: FlexRoleManager.is_client(b))

@functools.cached_property
def aggregators(self):
Expand All @@ -188,7 +196,7 @@ def aggregators(self):
Returns:
FlexPool: Pool containing all the aggregators from a pool
"""
return self.filter(lambda a, b: FlexRoleManager.is_aggregator(b))
return self.select(lambda a, b: FlexRoleManager.is_aggregator(b))

@functools.cached_property
def servers(self):
Expand All @@ -197,7 +205,7 @@ def servers(self):
Returns:
FlexPool: Pool containing all the servers from a pool
"""
return self.filter(lambda a, b: FlexRoleManager.is_server(b))
return self.select(lambda a, b: FlexRoleManager.is_server(b))

def validate(self): # sourcery skip: de-morgan
"""Function that checks whether the object is correct or not."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@
"metadata": {},
"outputs": [],
"source": [
"#Filter clients\n",
"#Select clients\n",
"clients_per_round=20\n",
"node_dropout = 1-(clients_per_round/len(clients))\n",
"selected_clients_pool = clients.filter(node_dropout=node_dropout)\n",
"selected_clients_pool = clients.select(clients_per_round)\n",
"selected_clients = selected_clients_pool.clients\n",
"\n",
"print(f\"Server node is indentified by key \\\"{servers.actor_ids[0]}\\\"\")\n",
Expand Down Expand Up @@ -198,7 +197,7 @@
"\n",
"@collect_clients_weights\n",
"def get_clients_weights(client_flex_model: FlexModel):\n",
" return client_flex_model.model.get_weights()\n",
" return client_flex_model[\"model\"].get_weights()\n",
"\n",
"aggregators.map(get_clients_weights, selected_clients)"
]
Expand Down Expand Up @@ -246,7 +245,7 @@
"\n",
"@set_aggregated_weights\n",
"def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):\n",
" server_flex_model.model.set_weights(aggregated_weights) \n",
" server_flex_model[\"model\"].set_weights(aggregated_weights) \n",
"\n",
"aggregators.map(set_agreggated_weights_to_server, servers)"
]
Expand All @@ -269,7 +268,7 @@
"\n",
"@evaluate_server_model\n",
"def evaluate_global_model(server_flex_model: FlexModel, test_data=None):\n",
" return server_flex_model.model.evaluate(test_data.X_data, test_data.y_data, verbose=False)\n",
" return server_flex_model[\"model\"].evaluate(test_data.X_data, test_data.y_data, verbose=False)\n",
"\n",
"metrics = servers.map(evaluate_global_model, test_data=test_data)\n",
"print(metrics[0])"
Expand All @@ -295,8 +294,7 @@
" pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=build_server_model)\n",
" for i in range(n_rounds):\n",
" print(f\"\\nRunning round: {i+1} of {n_rounds+1}\")\n",
" node_dropout = 1-(clients_per_round/len(pool.clients))\n",
" selected_clients_pool = pool.clients.filter(node_dropout=node_dropout)\n",
" selected_clients_pool = pool.clients.select(clients_per_round)\n",
" selected_clients = selected_clients_pool.clients\n",
" print(f\"Selected clients for this round: {len(selected_clients)}\")\n",
" # Deploy the server model to the selected clients\n",
Expand All @@ -319,6 +317,13 @@
"source": [
"train_n_rounds(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,9 @@
"metadata": {},
"outputs": [],
"source": [
"#Filter clients\n",
"#Select clients\n",
"clients_per_round=20\n",
"node_dropout = 1-(clients_per_round/len(clients))\n",
"selected_clients_pool = clients.filter(node_dropout=node_dropout)\n",
"selected_clients_pool = clients.select(clients_per_round)\n",
"selected_clients = selected_clients_pool.clients\n",
"\n",
"print(f\"Server node is indentified by key \\\"{server.actor_ids[0]}\\\"\")\n",
Expand Down Expand Up @@ -205,8 +204,7 @@
" pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=get_model())\n",
" for i in range(n_rounds):\n",
" print(f\"\\nRunning round: {i+1} of {n_rounds+1}\")\n",
" node_dropout = 1-(clients_per_round/len(pool.clients))\n",
" selected_clients_pool = pool.clients.filter(node_dropout=node_dropout)\n",
" selected_clients_pool = pool.clients.select(clients_per_round)\n",
" selected_clients = selected_clients_pool.clients\n",
" print(f\"Selected clients for this round: {len(selected_clients)}\")\n",
" # Deploy the server model to the selected clients\n",
Expand Down
16 changes: 8 additions & 8 deletions notebooks/Federating a clustering model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@
"@init_server_model\n",
"def build_server_model(n_clusters=10):\n",
" flex_model = FlexModel()\n",
" flex_model.model = KMeans(n_clusters=n_clusters)\n",
" flex_model.model._n_threads = 1\n",
" flex_model[\"model\"] = KMeans(n_clusters=n_clusters)\n",
" flex_model[\"model\"]._n_threads = 1\n",
" return flex_model\n",
"\n",
"\n",
Expand Down Expand Up @@ -206,7 +206,7 @@
"outputs": [],
"source": [
"def train(client_flex_model: FlexModel, client_data: Dataset):\n",
" client_flex_model.model.fit(client_data.X_data)\n",
" client_flex_model[\"model\"].fit(client_data.X_data)\n",
"\n",
"clients.map(train)"
]
Expand All @@ -229,7 +229,7 @@
"\n",
"@collect_clients_weights\n",
"def get_clients_weights(client_flex_model: FlexModel):\n",
" return client_flex_model.model.cluster_centers_\n",
" return client_flex_model[\"model\"].cluster_centers_\n",
"\n",
"aggregators.map(get_clients_weights, clients)\n",
"# The same as:\n",
Expand Down Expand Up @@ -287,7 +287,7 @@
"\n",
"@set_aggregated_weights\n",
"def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):\n",
" server_flex_model.model.cluster_centers_ = aggregated_weights\n",
" server_flex_model[\"model\"].cluster_centers_ = aggregated_weights\n",
"\n",
"# Set aggregated weights in the server model\n",
"aggregators.map(set_agreggated_weights_to_server, servers)\n",
Expand All @@ -313,8 +313,8 @@
"\n",
"@evaluate_server_model\n",
"def evaluate_global_model(server_flex_model: FlexModel, test_data=None):\n",
" preds = server_flex_model.model.predict(test_data.X_data)\n",
" plot_k_means(server_flex_model.model, test_data.X_data, title=\"K-means using federated data\")\n",
" preds = server_flex_model[\"model\"].predict(test_data.X_data)\n",
" plot_k_means(server_flex_model[\"model\"], test_data.X_data, title=\"K-means using federated data\")\n",
" score_model(test_data.y_data, preds)\n",
"\n",
"servers.map(evaluate_global_model, test_data=test_iris)"
Expand Down Expand Up @@ -386,7 +386,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.16"
},
"vscode": {
"interpreter": {
Expand Down
163 changes: 28 additions & 135 deletions notebooks/Federating a linear regression model.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit fb026d1

Please sign in to comment.