Skip to content

Commit

Permalink
Merge pull request #93 from FLEXible-FL/round-properly-number-of-node…
Browse files Browse the repository at this point in the history
…s-in-filter

Fix rounding in fiter
  • Loading branch information
AlArgente authored Jun 27, 2023
2 parents b89851e + 6af79bf commit f1946a2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flex/pool/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def filter(self, func: Callable = None, node_dropout: float = 0.0, **kwargs):
"""
node_dropout = max(0, node_dropout)
node_dropout = max(1 - min(node_dropout, 1), 0)
node_dropout = int(len(self._actors) * node_dropout)
selected_nodes = random.sample(list(self._actors.keys()), node_dropout)
num_nodes = round(len(self._actors) * node_dropout)
selected_nodes = random.sample(list(self._actors.keys()), num_nodes)
new_actors = FlexActors()
new_data = FedDataset()
new_models = {}
Expand Down
13 changes: 13 additions & 0 deletions tests/pool/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ def test_filter(self):
for _, actor_role in new_p._actors.items()
)

def test_filter_dropout(self):
iris = load_iris()
tmp = Dataset(X_data=iris.data, y_data=iris.target)
self._iris_many_clients = FedDataDistribution.iid_distribution(
tmp, n_clients=100
)
p = FlexPool.client_server_architecture(self._iris, lambda *args: None)
pool_size = len(p)
assert all(
len(p.filter(node_dropout=1 - (i / pool_size))) == i
for i in range(pool_size)
)

def test_client_property(self):
p = FlexPool.p2p_architecture(self._fld, init_func=lambda *args: None)
new_p = p.clients
Expand Down

0 comments on commit f1946a2

Please sign in to comment.