-
Notifications
You must be signed in to change notification settings - Fork 5
/
model.py
104 lines (85 loc) · 3.32 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
class SiameseNet(nn.Module):
    """
    A Convolutional Siamese Network for One-Shot Learning.

    Siamese networks learn image representations via a supervised
    metric-based approach. Once tuned, their learned features can be
    leveraged for one-shot learning without any retraining.

    The architecture follows Koch et al. and assumes 105x105 single-channel
    (grayscale) input images; the per-layer size comments below trace that
    input through the network.

    References
    ----------
    - Koch et al., https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf
    """
    def __init__(self):
        super().__init__()
        # Convolutional feature extractor. Output sizes (channels@H*W)
        # are for a 1x105x105 input.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 10),   # 64@96*96
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),        # 64@48*48
            nn.Conv2d(64, 128, 7),
            nn.ReLU(inplace=True),  # 128@42*42
            nn.MaxPool2d(2),        # 128@21*21
            nn.Conv2d(128, 128, 4),
            nn.ReLU(inplace=True),  # 128@18*18
            nn.MaxPool2d(2),        # 128@9*9
            nn.Conv2d(128, 256, 4),
            nn.ReLU(inplace=True),  # 256@6*6
        )
        # Fully-connected embedding head: 256*6*6 = 9216 flattened
        # features -> 4096-dim sigmoid-activated embedding.
        # NOTE: the attribute name "liner" (sic, should be "linear") is
        # kept as-is so existing state_dict checkpoints still load.
        self.liner = nn.Sequential(nn.Linear(9216, 4096), nn.Sigmoid())
        self.out = nn.Linear(4096, 1)

        # Weight init: Kaiming for conv layers (ReLU activations),
        # Xavier for linear layers (sigmoid activation).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def sub_forward(self, x):
        """
        Forward pass the input image through 1 subnetwork.

        Args
        ----
        - x: a tensor of size (B, C, H, W). Contains either the first or
          second image of the pair across the input batch.

        Returns
        -------
        - out: a tensor of size (B, 4096). The hidden vector representation
          of the input x (values in [0, 1] due to the sigmoid head).
        """
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten to (B, 9216)
        x = self.liner(x)
        return x

    def forward(self, x1, x2):
        """
        Forward pass the input image pairs through both subtwins. An image
        pair is composed of a left tensor x1 and a right tensor x2.

        Concretely, we compute the component-wise L1 distance of the hidden
        representations generated by each subnetwork, and feed the difference
        to a final fc-layer that produces a raw similarity logit per pair.

        Note that the sigmoid is intentionally NOT applied here: train with
        ``nn.BCEWithLogitsLoss`` for numerical stability, and apply
        ``torch.sigmoid`` to the returned scores when a probability in
        [0, 1] is needed (near 1 for same-class pairs, near 0 otherwise).

        Args
        ----
        - x1: a tensor of size (B, C, H, W). The left image pairs along the
          batch dimension.
        - x2: a tensor of size (B, C, H, W). The right image pairs along the
          batch dimension.

        Returns
        -------
        - scores: a tensor of size (B, 1). Raw (pre-sigmoid) similarity
          logits for each pair along the batch dimension.
        """
        # encode image pairs
        h1 = self.sub_forward(x1)
        h2 = self.sub_forward(x2)

        # compute l1 distance between the two embeddings
        diff = torch.abs(h1 - h2)

        # score the similarity between the 2 encodings
        scores = self.out(diff)

        # return raw scores (without sigmoid) and use bce_with_logits
        # for increased numerical stability
        return scores
if __name__ == '__main__':
    # Quick smoke check: instantiate the model and show its architecture.
    model = SiameseNet()
    print(model)