diff --git a/flex/pool/pool.py b/flex/pool/pool.py index b87ab890..17590518 100644 --- a/flex/pool/pool.py +++ b/flex/pool/pool.py @@ -144,8 +144,8 @@ def filter(self, func: Callable = None, node_dropout: float = 0.0, **kwargs): """ node_dropout = max(0, node_dropout) node_dropout = max(1 - min(node_dropout, 1), 0) - node_dropout = int(len(self._actors) * node_dropout) - selected_nodes = random.sample(list(self._actors.keys()), node_dropout) + num_nodes = round(len(self._actors) * node_dropout) + selected_nodes = random.sample(list(self._actors.keys()), num_nodes) new_actors = FlexActors() new_data = FedDataset() new_models = {} diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 5a7e70c0..b15f23d1 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -160,6 +160,19 @@ def test_filter(self): for _, actor_role in new_p._actors.items() ) + def test_filter_dropout(self): + iris = load_iris() + tmp = Dataset(X_data=iris.data, y_data=iris.target) + self._iris_many_clients = FedDataDistribution.iid_distribution( + tmp, n_clients=100 + ) + p = FlexPool.client_server_architecture(self._iris, lambda *args: None) + pool_size = len(p) + assert all( + len(p.filter(node_dropout=1 - (i / pool_size))) == i + for i in range(pool_size) + ) + def test_client_property(self): p = FlexPool.p2p_architecture(self._fld, init_func=lambda *args: None) new_p = p.clients