From 54fbadd7c9b8256dcbe1d8175ee7583bb43f0bcd Mon Sep 17 00:00:00 2001 From: xehartnort Date: Thu, 22 Jun 2023 13:41:33 +0200 Subject: [PATCH 1/2] add a better way to share data among clients per each class --- flex/data/fed_data_distribution.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/flex/data/fed_data_distribution.py b/flex/data/fed_data_distribution.py index aa6ada61..01c7e957 100644 --- a/flex/data/fed_data_distribution.py +++ b/flex/data/fed_data_distribution.py @@ -355,11 +355,18 @@ def __configure_weights_per_class( sorted_classes = np.sort(np.unique(data.y_data)) assigned_classes = [] if isinstance(config.classes_per_client, int): + histogram = np.zeros_like(sorted_classes) for _ in range(config.n_clients): - n = rng.choice( - sorted_classes, size=config.classes_per_client, replace=False - ) - assigned_classes.append(n) + individual_assigned_classes = [] + for _ in range(config.classes_per_client): + most_frequent = np.max(histogram) + available_classes = sorted_classes[histogram < most_frequent] + if len(available_classes) == 0: + available_classes = sorted_classes + n = rng.choice(available_classes, size=1, replace=False) + histogram[n] = histogram[n] + 1 + individual_assigned_classes.append(n) + assigned_classes.append(individual_assigned_classes) config.classes_per_client = assigned_classes elif isinstance(config.classes_per_client, tuple): num_classes_per_client = rng.integers( From 35118f38ba9d0dcf54670e7e01252dd9a80cfa3e Mon Sep 17 00:00:00 2001 From: xehartnort Date: Thu, 22 Jun 2023 14:29:58 +0200 Subject: [PATCH 2/2] remove dependency with label type --- flex/data/fed_data_distribution.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/flex/data/fed_data_distribution.py b/flex/data/fed_data_distribution.py index 01c7e957..d8c5fe6f 100644 --- a/flex/data/fed_data_distribution.py +++ b/flex/data/fed_data_distribution.py @@ -360,12 +360,13 @@ def __configure_weights_per_class( individual_assigned_classes = [] for _ in range(config.classes_per_client): most_frequent = np.max(histogram) - available_classes = sorted_classes[histogram < most_frequent] - if len(available_classes) == 0: - available_classes = sorted_classes - n = rng.choice(available_classes, size=1, replace=False) - histogram[n] = histogram[n] + 1 - individual_assigned_classes.append(n) + available_classes_indexes = np.arange(len(sorted_classes)) + tmp_available_indexes = histogram < most_frequent + if sum(tmp_available_indexes) != 0: + available_classes_indexes = available_classes_indexes[tmp_available_indexes] + indx = rng.choice(available_classes_indexes, size=1, replace=False) + histogram[indx] = histogram[indx] + 1 + individual_assigned_classes.append(sorted_classes[indx]) assigned_classes.append(individual_assigned_classes) config.classes_per_client = assigned_classes elif isinstance(config.classes_per_client, tuple):