-
Notifications
You must be signed in to change notification settings - Fork 9
/
clients.py
84 lines (67 loc) · 3.56 KB
/
clients.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from getData import GetDataSet
class client(object):
    """A single federated-learning participant holding its own local training data."""

    def __init__(self, trainDataSet, dev):
        """Store the client's local dataset and target device.

        Args:
            trainDataSet: Local training dataset (ClientsGroup passes a
                ``TensorDataset`` of (data, label) pairs).
            dev: Device each mini-batch is moved to before the forward pass.
        """
        self.train_ds = trainDataSet
        self.dev = dev
        # DataLoader over train_ds; (re)built on every localUpdate call.
        self.train_dl = None
        # Not written by this class; presumably reserved for callers — confirm.
        self.local_parameters = None

    def localUpdate(self, localEpoch, localBatchSize, Net, lossFun, opti, global_parameters):
        """Train ``Net`` on this client's data, starting from the global weights.

        Args:
            localEpoch: Number of passes over the local dataset.
            localBatchSize: Mini-batch size of the local DataLoader.
            Net: Model trained in place; its weights are first overwritten
                with ``global_parameters``.
            lossFun: Callable ``lossFun(preds, label)`` returning a scalar loss.
            opti: Optimizer over ``Net``'s parameters.
            global_parameters: State dict loaded (strictly) into ``Net``
                before training.

        Returns:
            A new state dict holding independent copies of ``Net``'s trained
            tensors. Because they are copies, reusing ``Net`` afterwards
            (e.g. for the next client) cannot overwrite them.
        """
        Net.load_state_dict(global_parameters, strict=True)
        # Make sure dropout / batch-norm behave as in training even if the
        # caller left the model in eval mode (e.g. after evaluating it).
        Net.train()
        self.train_dl = DataLoader(self.train_ds, batch_size=localBatchSize, shuffle=True)
        for _ in range(localEpoch):
            for data, label in self.train_dl:
                data, label = data.to(self.dev), label.to(self.dev)
                # Clear gradients before backward so nothing accumulated
                # before this call leaks into the first update.
                opti.zero_grad()
                preds = Net(data)
                loss = lossFun(preds, label)
                loss.backward()
                opti.step()
        # state_dict() tensors share storage with the live parameters; clone
        # them so the returned weights stay fixed once Net is reused.
        return {key: value.detach().clone() for key, value in Net.state_dict().items()}

    def local_val(self):
        """Placeholder for local validation; not implemented."""
        pass
class ClientsGroup(object):
    """A set of federated-learning clients that share one partitioned dataset.

    On construction the dataset is loaded through ``GetDataSet`` and every
    client receives two shards of the training set. The full test set and
    the full training set are also kept on the group as DataLoaders.
    """

    def __init__(self, dataSetName, isIID, numOfClients, dev):
        """Create the group and allocate data to every client.

        Args:
            dataSetName: Dataset name, forwarded to ``GetDataSet``.
            isIID: Whether the data is independent and identically
                distributed; forwarded to ``GetDataSet``.
            numOfClients: Number of participating clients.
            dev: Device (GPU or CPU) handed to each client.

        Raises:
            ValueError: propagated from ``dataSetBalanceAllocation`` when the
                client count cannot be served by the training set.
        """
        self.data_set_name = dataSetName
        self.is_iid = isIID
        self.num_of_clients = numOfClients
        self.dev = dev
        # Maps 'client{i}' -> client instance.
        self.clients_set = {}
        self.test_data_loader = None
        # Declared alongside test_data_loader; filled in by dataSetBalanceAllocation.
        self.train_data_loader = None
        self.dataSetBalanceAllocation()

    def dataSetBalanceAllocation(self):
        """Load the dataset and give each client two random training shards.

        The training set is cut into contiguous shards of
        ``train_data_size // num_of_clients // 2`` samples. A random
        permutation (global NumPy RNG) of shard indices decides which two
        shards each client gets. Samples in shards beyond the first
        ``2 * num_of_clients`` of that permutation are not assigned to any
        client. Any label skew for non-IID runs presumably comes from how
        ``GetDataSet`` orders the training data — confirm in getData.

        Raises:
            ValueError: if ``num_of_clients`` is not positive or exceeds half
                the training-set size (each client needs two non-empty shards).
        """
        dataset = GetDataSet(self.data_set_name, self.is_iid)

        num_clients = self.num_of_clients
        train_size = dataset.train_data_size
        # Without this guard shard_size would be 0 and the shard count
        # computation below would fail with an opaque ZeroDivisionError.
        if num_clients <= 0 or train_size < 2 * num_clients:
            raise ValueError(
                "num_of_clients must be between 1 and half the training-set size "
                "({}), got {}".format(train_size, num_clients)
            )

        # Test set; labels are converted from vectors to integer class ids,
        # e.g. [0,0,1] -> 2.
        test_data = torch.tensor(dataset.test_data)
        test_label = torch.argmax(torch.tensor(dataset.test_label), dim=1)
        self.test_data_loader = DataLoader(TensorDataset(test_data, test_label), batch_size=100, shuffle=False)

        # Full training set, also exposed as an unshuffled loader.
        train_data = dataset.train_data
        train_label = dataset.train_label
        self.train_data_loader = DataLoader(
            TensorDataset(torch.tensor(train_data), torch.argmax(torch.tensor(train_label), dim=1)),
            batch_size=100,
            shuffle=False,
        )

        shard_size = train_size // num_clients // 2
        shards_id = np.random.permutation(train_size // shard_size)

        def take_shard(array, shard):
            # Contiguous slice of `array` belonging to shard index `shard`.
            start = shard * shard_size
            return array[start:start + shard_size]

        # Build num_of_clients clients, each owning two shards; arrays are
        # converted from NumPy to tensors for the client's dataset.
        for i in range(num_clients):
            first, second = shards_id[2 * i], shards_id[2 * i + 1]
            local_data = np.vstack((take_shard(train_data, first), take_shard(train_data, second)))
            local_label = np.vstack((take_shard(train_label, first), take_shard(train_label, second)))
            local_label = np.argmax(local_label, axis=1)
            someone = client(TensorDataset(torch.tensor(local_data), torch.tensor(local_label)), self.dev)
            self.clients_set['client{}'.format(i)] = someone
# 测试ClientGroup
# if __name__=="__main__":
# MyClients = ClientsGroup('mnist', True, 100, 1)
# print(MyClients.clients_set['client10'].train_ds[0:100])
# print(MyClients.clients_set['client11'].train_ds[400:500])