-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmid_averaging.py
131 lines (101 loc) · 4.13 KB
/
mid_averaging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Script for mid averaging.

This module combines the weight updates received from end client devices into
a single checkpoint.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
import numpy as np
from keras.models import load_model
# Set the TF_CPP_MIN_LOG_LEVEL environment variable for this process.
# NOTE(review): "2" presumably hides TensorFlow's INFO and WARNING messages.
# The variable is assigned after keras has already been imported above --
# confirm that it still takes effect at this point.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
def mid_averaging(updates, model_path, ckpt_path):
    """
    Sum the weight updates received from end client devices and save the result.

    The model stored at ``model_path`` supplies the architecture. The weights
    of the first update are loaded as the starting value, and the weights of
    every further update are added to them layer by layer. The summed weights
    are then written to ``ckpt_path``.

    With an empty ``updates`` sequence the weights stored in the model file
    are saved unchanged.

    NOTE(review): despite the name, the sum is not divided by the number of
    devices or batches in this function; presumably a later stage performs
    the normalisation -- confirm against the caller.

    Args:
        updates: Sequence of ``(n, weight_updates_path)`` pairs, one per
            device. ``n`` is not used in the calculation.
        model_path: Path of the saved model passed to ``load_model``.
        ckpt_path: Path the summed weights are saved to.
    """
    print("Model path: ", model_path)
    print("Checkpoint path: ", ckpt_path)
    print("Updates array: ", updates)

    # Two instances of the same architecture: one accumulates the running sum,
    # the other is a scratch model that each further device checkpoint is
    # loaded into.
    sum_weight_updates = load_model(model_path)
    device_weight_updates = load_model(model_path)

    for device_index, (_, weight_updates_path) in enumerate(updates):
        if device_index == 0:
            # The first device initialises the running sum.
            sum_weight_updates.load_weights(weight_updates_path)
            continue

        device_weight_updates.load_weights(weight_updates_path)
        for sum_layer, device_layer in zip(
            sum_weight_updates.layers, device_weight_updates.layers
        ):
            # Add tensor by tensor. A layer's weight list usually holds arrays
            # of different shapes (e.g. kernel and bias), so it must not be
            # packed into a single ndarray: NumPy rejects such ragged input.
            sum_layer.set_weights(
                [
                    old_sum + device_update
                    for old_sum, device_update in zip(
                        sum_layer.get_weights(), device_layer.get_weights()
                    )
                ]
            )

    # Copy the summed weights into a freshly loaded model and save them.
    # NOTE(review): saving ``sum_weight_updates`` directly looks equivalent;
    # the extra model instance is kept to preserve the original behaviour.
    model = load_model(model_path)
    for target_layer, sum_layer in zip(model.layers, sum_weight_updates.layers):
        target_layer.set_weights(sum_layer.get_weights())

    model.save_weights(ckpt_path)
    print("New checkpoint saved at ", ckpt_path)
def main():
    """
    Parse the command line and run mid averaging.

    The script is run from the fl-selector server after all weight updates
    from end client devices are received.

    Command-line arguments:
        --cf / --ckpt-file-path: path passed to ``mid_averaging`` as the
            checkpoint path.
        --mf / --model-file-path: path passed to ``mid_averaging`` as the
            model path.
        --u / --updates: flat list ``n_1 path_1 n_2 path_2 ...`` holding one
            integer/path pair per device.

    The process exits through ``parser.error`` (usage message, status 2) when
    the update list does not consist of complete pairs or when a count is not
    an integer.
    """
    # define Arguments
    parser = argparse.ArgumentParser(description="Perform Federated Averaging")
    parser.add_argument("--cf", "--ckpt-file-path", required=True, nargs=1)
    parser.add_argument("--mf", "--model-file-path", required=True, nargs=1)
    parser.add_argument("--u", "--updates", required=True, nargs="*")

    # parse arguments
    args = parser.parse_args()

    # argparse derives the attribute name from the first long option string,
    # so the values are stored as args.cf, args.mf and args.u. The two path
    # options use nargs=1 and therefore arrive as one-element lists.
    ckpt_path = args.cf[0]
    model_path = args.mf[0]
    update_args = args.u

    print(update_args)
    print("Len: ", len(update_args))

    # Every device contributes exactly two values; reject a dangling value
    # up front instead of failing with an IndexError halfway through.
    if len(update_args) % 2 != 0:
        parser.error(
            "--u/--updates expects pairs of <n> <path>, got {} value(s)".format(
                len(update_args)
            )
        )

    updates = []
    for i in range(0, len(update_args), 2):
        print("index: ", i)
        try:
            n = int(update_args[i])
        except ValueError:
            parser.error(
                "--u/--updates: expected an integer count, got {!r}".format(
                    update_args[i]
                )
            )
        updates.append((n, update_args[i + 1]))

    # run federated averaging
    mid_averaging(updates, model_path, ckpt_path)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()