-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model_merge.py
61 lines (45 loc) · 1.51 KB
/
test_model_merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
import torch.optim as optim
# Assume you have network A, B, and C
# For instance:
class NetworkA(nn.Module):
# Assume input size 10 and output size 20
def __init__(self):
super(NetworkA, self).__init__()
self.layer = nn.Linear(10, 20)
def forward(self, x):
return self.layer(x)
class NetworkB(nn.Module):
# Assume input size 20 and output size 10
def __init__(self):
super(NetworkB, self).__init__()
self.layer = nn.Linear(20, 10)
def forward(self, x):
return self.layer(x)
class NetworkC(nn.Module):
# Assume input size 20 and output size 10
def __init__(self):
super(NetworkC, self).__init__()
self.layer = nn.Linear(20, 10)
def forward(self, x):
return self.layer(x)
netA = NetworkA()
netB = NetworkB()
netC = NetworkC()
optimizer = optim.SGD(list(netA.parameters()) + list(netB.parameters()) + list(netC.parameters()), lr=0.001)
criterion = nn.MSELoss()
for epoch in range(100): # loop over the dataset multiple times
optimizer.zero_grad()
# get the inputs and labels
inputs, labels = torch.rand(10), torch.rand(10) # assume inputs and labels are PyTorch tensors
# forward + backward + optimize
output_A = netA(inputs)
output_B = netB(output_A)
output_C = netC(output_A)
loss_B = criterion(output_B, labels)
loss_C = criterion(output_C, labels)
loss = loss_B + loss_C
print(f"loss = {loss}")
loss.backward()
optimizer.step()