ol_wat2.py (forked from tbwxmu/SAMPN)
from argparse import ArgumentParser, Namespace
import os, random
import torch
from torch.utils.data import DataLoader
from utils import load_checkpoint
from cv import DGLDataset, DGLCollator
from tqdm import tqdm
import numpy as np
from cv import evaluate_batch
from utils import get_metric_func
from scaler import StandardScaler, minmaxScaler
seed = 3032

def seed_torch(seed=seed):
    # Seed every RNG in use (Python, hash, NumPy, PyTorch CPU/GPU) and force
    # deterministic cuDNN behaviour so attention visualizations are reproducible.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
    print(f'used seed in seed_torch={seed}')

def worker_init_fn(worker_id=0):
    # Re-seed NumPy in DataLoader worker processes (only relevant if num_workers > 0).
    np.random.seed(seed)
def visualize_attention(args: Namespace):
    """Visualizes attention weights."""
    print(f'Loading model from "{args.checkpoint_path}"')
    model = load_checkpoint(args.checkpoint_path, cuda=args.cuda)
    mpn = model.encoder
    print(f'mpn:-->{type(mpn)}')
    print(f'MPNencoder attributes:{mpn.encoder.__dict__}')

    # Fall back to the bundled viz.csv if the requested data file is missing or empty.
    print('Loading data')
    if os.path.exists(args.data_path) and os.path.getsize(args.data_path) > 0:
        DGLtest = args.data_path
        print(f'Loading data -->{DGLtest}')
    else:
        direct = 'data_RE2/tmp/'
        DGLtest = direct + 'viz.csv'
        print(f'Loading data -->{DGLtest}')

    viz_data = DGLDataset(DGLtest, training=False)
    viz_dataloader = DataLoader(viz_data, batch_size=args.batch_size,
                                shuffle=False, num_workers=0,
                                collate_fn=DGLCollator(training=False),
                                drop_last=False,
                                worker_init_fn=worker_init_fn)

    metric_func = get_metric_func(metric=args.metric)
    scaler = None
    for it, result_batch in enumerate(tqdm(viz_dataloader)):
        batch_sm = result_batch['sm']
        label_batch = result_batch['labels']
        if args.dataset_type == 'regression':
            # Fit a label scaler on the current batch so predictions can be
            # transformed back to the original scale during evaluation.
            if args.scale == 'standardization':
                print('Fitting scaler (Z-score standardization)')
                scaler = StandardScaler().fit(label_batch)
                y_train_scaled = scaler.transform(label_batch)
                print(f'train data mean:{scaler.means}\nstd:{scaler.stds}\n')
            elif args.scale == 'normalization':
                print('Fitting scaler (min-max normalization)')
                scaler = minmaxScaler().fit(label_batch)
                y_train_scaled = scaler.transform(label_batch)
                print(f'train data min:{scaler.mins}\ntrain data max:{scaler.maxs}\n')
            else:
                raise ValueError('scaler not implemented, use one of [standardization, normalization]')
        else:
            scaler = None
        # Draw and save the attention weight maps for this batch of SMILES.
        mpn.viz_attention(batch_sm, viz_dir=args.viz_dir)

    test_targets, test_preds, test_scores = evaluate_batch(args,
                                                           model=model,
                                                           data=viz_dataloader,
                                                           num_tasks=args.num_tasks,
                                                           metric_func=metric_func,
                                                           dataset_type=args.dataset_type,
                                                           scaler=scaler,
                                                           logger=None,
                                                           Foldth=0,
                                                           predsLog=args.save_dir)
    print(f'viz PNGs saved to {args.viz_dir}')
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str, default='data_RE2/tmp/0_testcc',
                        help='Path to data CSV file')
    parser.add_argument('--viz_dir', type=str, default='viz_attention',
                        help='Path where attention PNGs will be saved')
    parser.add_argument('--checkpoint_path', type=str, default='save_test/fold_0/model_0/model.pt',
                        help='Path to a model checkpoint')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='Turn off cuda')
    args = parser.parse_args()

    # NOTE: the hard-coded assignments below override the parsed CLI arguments.
    args.seed = seed
    seed_torch(seed)
    args.sumstyle = True
    args.data_path = 'data_RE2/ol_wat.csv'
    args.data_filename = os.path.basename(args.data_path) + f'_seed{args.seed}'
    args.viz_dir = 'png_seed3032logp_ol'
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.checkpoint_path = 'save_test/fold_0/model_0/LogP_moleculenet.csv_seed3032_model.pt'
    args.batch_size = 128
    args.attention = True
    args.dataset_type = 'regression'
    args.scale = 'normalization'
    args.num_tasks = 1
    args.metric = 'rmse'
    args.save_dir = 'save_test'
    del args.no_cuda
    os.makedirs(args.viz_dir, exist_ok=True)
    print(f'args:\t-->{args}')
    visualize_attention(args)
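
# Usage sketch (not part of the original file; assumes the repo layout implied by
# the paths hard-coded above, i.e. the trained checkpoint and data_RE2/ol_wat.csv
# already exist):
#
#   python ol_wat2.py
#
# Because data_path, viz_dir, checkpoint_path, batch_size, etc. are overwritten
# after parse_args(), edit the assignments in the __main__ block rather than the
# CLI flags to point the script at a different checkpoint or CSV.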