-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathpassport_generator.py
43 lines (30 loc) · 1.09 KB
/
passport_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import random
import torch
def get_key(dataset_loader, n=32):
    """Draw ``n`` distinct random samples from a loader's dataset as a key batch.

    Args:
        dataset_loader: Object exposing a ``.dataset`` attribute; the dataset
            must support ``len()`` and integer indexing, and each item must
            unpack into exactly two values ``(img, target)``.
        n: Number of samples to draw, without replacement.

    Returns:
        A tuple ``(batch, indices)``. ``batch`` holds the ``n`` images joined
        along a new leading dimension. ``indices`` lists the dataset positions
        that were drawn, in draw order.

    Raises:
        ValueError: if ``n`` exceeds ``len(dataset)`` (raised by ``random.sample``).
    """
    source = dataset_loader.dataset
    # Draws from the module-global ``random`` state; seed ``random`` for
    # reproducible keys.
    picked = random.sample(range(len(source)), n)
    pairs = (source[idx] for idx in picked)
    # Only the image half of each pair is used; the label is discarded.
    images = [img for img, _label in pairs]
    return torch.stack(images, dim=0), picked
def get_intermediate_key(input_key, intermediate_key_name, pretrained_model):
    """Forward ``input_key`` through ``pretrained_model.features`` up to a named layer.

    The layers of ``pretrained_model.features`` are applied in order without
    gradient tracking. When the layer at position ``i`` is reached and
    ``'features.<i>'`` equals ``intermediate_key_name``, the activation that
    would be fed INTO that layer is returned. The layer itself is not applied.

    Args:
        input_key: Input passed to the first layer.
        intermediate_key_name: Target name of the form ``'features.<index>'``.
            The comparison is an exact string match (e.g. ``'features.05'``
            never matches).
        pretrained_model: Object with an iterable ``features`` attribute of
            callable layers.

    Returns:
        The input to the named layer, or ``None`` if no layer name matches.
        This includes any index at or beyond the number of layers.
    """
    activation = input_key
    with torch.no_grad():
        for idx, layer in enumerate(pretrained_model.features):
            if f'features.{idx}' == intermediate_key_name:
                return activation
            activation = layer(activation)
    # NOTE(review): an unmatched name silently yields None; callers should
    # check for it.
    return None
def set_key(pretrained_model, target_model,
            key_x, key_y, ind=None):
    """Install key inputs on ``target_model`` via its ``set_intermediate_keys``.

    A 3-D ``key_x`` is given a leading batch dimension of size 1, presumably
    turning a single (C, H, W) image into a batch (TODO confirm the layout).
    The same applies to ``key_y`` when it is not None. Progress messages are
    printed to stdout before and after installation.

    Args:
        pretrained_model: Forwarded unchanged to ``set_intermediate_keys``.
        target_model: Object whose ``set_intermediate_keys`` method is called.
        key_x: Key tensor.
        key_y: Optional second key tensor, or None.
        ind: Optional extra argument, passed through only when not None.
    """
    def _batched(tensor):
        # Only 3-D inputs receive a leading dimension; other ranks pass through.
        return tensor.unsqueeze(0) if tensor.dim() == 3 else tensor

    print('Setting keys')
    key_x = _batched(key_x)
    if key_y is not None:
        key_y = _batched(key_y)
    print('Key size', key_x.size())

    # Omit ``ind`` entirely (rather than passing None) when it was not given,
    # so the callee's own default applies.
    call_args = (pretrained_model, key_x, key_y)
    if ind is not None:
        call_args += (ind,)
    target_model.set_intermediate_keys(*call_args)
    print('Key is set!')