-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathlast_layer.py
35 lines (33 loc) · 1.21 KB
/
last_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from dataHelper import DatasetFolder
import re
import numpy as np
import os
import copy
from skimage.transform import resize
from helpers import makedir, find_high_activation_crop
import model
import push
import train_and_test as tnt
import save
from log import create_logger
from preprocess import mean, std, preprocess_input_function, undo_preprocess_input_function
def show_last_layer_connections(ppnet):
    """Print the prototype/class counts and return the last-layer weight.

    Args:
        ppnet: model object exposing ``num_prototypes``, ``num_classes`` and a
            ``last_layer`` attribute that has a ``weight``.

    Returns:
        ``ppnet.last_layer.weight`` itself, the live parameter and not a copy.
        Mutating the result mutates the model.

    NOTE(review): if ``last_layer`` is a ``torch.nn.Linear`` (assumed, not
    visible here), the weight is laid out as (num_classes, num_prototypes).
    That is the transpose of the (num_prototypes, num_classes) shape the old
    scratch buffer suggested. TODO confirm against the model definition.
    """
    # Side effect kept from the original: echo the two dimensions.
    print(ppnet.num_prototypes, ppnet.num_classes)
    # No scratch np.zeros buffer: it was overwritten before ever being used.
    return ppnet.last_layer.weight
def show_last_layer_connections_T(ppnet):
    """Print the prototype/class counts and return the transposed last-layer weight.

    Args:
        ppnet: model object exposing ``num_prototypes``, ``num_classes`` and a
            ``last_layer`` attribute that has a ``weight`` tensor.

    Returns:
        ``torch.transpose(ppnet.last_layer.weight, 0, 1)``. This is a view
        that shares storage with the live parameter, not a copy.

    NOTE(review): if ``last_layer`` is a ``torch.nn.Linear`` (assumed, not
    visible here), the result is laid out as (num_prototypes, num_classes).
    TODO confirm against the model definition.
    """
    # Side effect kept from the original: echo the two dimensions.
    print(ppnet.num_prototypes, ppnet.num_classes)
    # No scratch np.zeros buffer: it was overwritten before ever being used.
    weight = ppnet.last_layer.weight
    return torch.transpose(weight, 0, 1)