-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
75 lines (63 loc) · 2.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import pickle
import numpy as np
import random
import torch
import os
import sys
# Scratch directory for each supported dataset key. The paths are relative to
# the current working directory; save_rl_mtric() writes under
# "<dir>/eval_result/".
TMP_DIR = dict(
    esc='./tmp/esc',
    cima='./tmp/cima',
    cb='./tmp/cb',
)
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with *seed*.

    When CUDA is available, the generators of all CUDA devices are seeded
    as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def blockPrint():
    """Silence ``print`` by redirecting ``sys.stdout`` to the null device.

    If ``sys.stdout`` is already an open handle on ``os.devnull``, the call
    does nothing. This means repeated calls do not each leak a new file
    handle. Use ``enablePrint()`` to restore output.
    """
    current = sys.stdout
    # Avoid stacking a fresh devnull handle on top of one that is still open.
    if (getattr(current, 'name', None) == os.devnull
            and not getattr(current, 'closed', True)):
        return
    sys.stdout = open(os.devnull, 'w')
def enablePrint():
    """Restore ``sys.stdout`` to the interpreter's original stream.

    If the stream being replaced is a handle on ``os.devnull`` (as opened by
    ``blockPrint()``), it is closed so its file descriptor is not leaked.
    """
    previous = sys.stdout
    sys.stdout = sys.__stdout__
    # Only close the null-device handle; never close an arbitrary stream we
    # did not open.
    if (previous is not sys.__stdout__
            and getattr(previous, 'name', None) == os.devnull):
        previous.close()
def load_dataset(data_name, *, data_dir='../data'):
    """Load the train/test/valid splits of *data_name*.

    Each split is read from ``{data_dir}/{data_name}-{split}.txt``. Every
    non-blank line is evaluated as one Python expression (one record) and
    appended to that split's list.

    Args:
        data_name: Dataset prefix used in the file names.
        data_dir: Directory holding the split files. Defaults to
            ``'../data'``, which is resolved against the current working
            directory.

    Returns:
        A dict mapping ``'train'``, ``'test'`` and ``'valid'`` to lists of
        records.

    Raises:
        FileNotFoundError: If any split file is missing.
    """
    dataset = {'train': [], 'test': [], 'valid': []}
    for split, records in dataset.items():
        path = '%s/%s-%s.txt' % (data_dir, data_name, split)
        with open(path, 'r', encoding='utf-8') as infile:
            for line in infile:
                text = line.strip()
                if not text:
                    # Blank lines (e.g. a trailing empty line) would make
                    # eval() raise SyntaxError.
                    continue
                # SECURITY(review): eval() runs arbitrary code from the data
                # file; only use with trusted files. If every record is a
                # plain literal, ast.literal_eval is a safer drop-in -- TODO
                # confirm against the data and switch.
                records.append(eval(text))
    return dataset
def set_cuda(args):
    """Choose the torch device and parse the requested GPU ids.

    ``args`` must provide ``seed`` and ``gpu``, where ``gpu`` is a
    whitespace-separated string of integer device ids. When CUDA is
    available, the current CUDA device is seeded with ``args.seed`` and
    ``torch.backends.cudnn.deterministic`` is turned on.

    Returns:
        ``(device, devices_id)``. ``device`` is ``cuda:<first id>`` when CUDA
        is available and ``cpu`` otherwise. ``devices_id`` is the parsed
        list of ints in both cases.

    Raises:
        ValueError: If a token in ``args.gpu`` is not an integer.
        IndexError: If CUDA is available but ``args.gpu`` lists no ids.
    """
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
    gpu_ids = list(map(int, args.gpu.split()))
    if not cuda_ok:
        return torch.device("cpu"), gpu_ids
    return torch.device(f"cuda:{gpu_ids[0]}"), gpu_ids
def save_rl_mtric(dataset, filename, epoch, SR, mode='train'):
    """Append one block of RL metrics to ``<TMP_DIR[dataset]>/eval_result/<filename>.txt``.

    The ``eval_result`` directory is created if needed, whatever the mode.

    Args:
        dataset: Key into ``TMP_DIR``.
        filename: Log file name, without the ``.txt`` extension.
        epoch: Count written in the header line ("user epochs" in train
            mode, "user tuples" in test mode).
        SR: Indexable with at least three items, logged in order as SR,
            Avg@T and Rewards.
        mode: ``'train'`` or ``'test'``. Any other value writes nothing.

    Raises:
        KeyError: If ``dataset`` is not a key of ``TMP_DIR``.
        IndexError: If ``SR`` has fewer than three items. In that case
            nothing is written to the file.
    """
    result_dir = TMP_DIR[dataset] + '/eval_result/'
    # exist_ok avoids the race between checking for and creating the dir.
    os.makedirs(result_dir, exist_ok=True)

    if mode == 'train':
        header = '===========Train==============='
        count_line = 'Starting {} user epochs'
        label = 'training'
    elif mode == 'test':
        header = '===========Test==============='
        count_line = 'Testing {} user tuples'
        label = 'Testing'
    else:
        # Unknown modes are ignored, as before.
        return

    # Build every line first, so a bad SR cannot leave a partial record.
    lines = [
        header,
        count_line.format(epoch),
        '{} SR: {}'.format(label, SR[0]),
        '{} Avg@T: {}'.format(label, SR[1]),
        '{} Rewards: {}'.format(label, SR[2]),
        '================================',
    ]
    with open(result_dir + filename + '.txt', 'a') as f:
        f.write('\n'.join(lines) + '\n')