-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
74 lines (65 loc) · 2.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import sys
import os
import csv
import argparse
import random
from pathlib import Path
import numpy as np
import torch
import pandas as pd
import re
from torch.utils.data import DataLoader
try:
from torch_geometric.data import Batch
except ImportError:
pass
def set_seed(seed):
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch. When
    CUDA is available, it also explicitly seeds the current CUDA device.
    It then turns off cuDNN benchmark autotuning and requests deterministic
    cuDNN algorithms.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # Benchmark mode selects kernels by timing them, which can differ between
    # runs; disable it and ask for deterministic kernels instead.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def move_to(obj, device):
    """Recursively move ``obj`` and everything nested inside it to ``device``.

    Containers are rebuilt with every element moved:

    - dicts become plain dicts with the same keys;
    - lists stay lists;
    - tuples keep their type (plain tuples, namedtuples, ``torch.Size``).

    Python ``int``/``float`` values (``bool`` included) have no device and are
    returned unchanged. Any other leaf is assumed to expose ``.to(device)``,
    e.g. a ``torch.Tensor`` or a torch_geometric ``Batch`` (as used for
    MolPCBA).

    Args:
        obj: Object or (possibly nested) container to move.
        device: Target device, passed straight through to ``.to``.

    Returns:
        A new container mirroring ``obj`` with its leaves moved, or the moved
        (or unchanged numeric) leaf itself.

    Raises:
        AttributeError: If a leaf is neither numeric nor has a ``.to`` method.
    """
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        moved = [move_to(v, device) for v in obj]
        # namedtuples take their fields positionally; other tuple types
        # (plain tuple, torch.Size) accept a single iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*moved)
        return type(obj)(moved)
    if isinstance(obj, (int, float)):
        return obj
    # Assume obj is a Tensor or another type (like Batch, for MolPCBA)
    # that supports .to(device).
    return obj.to(device)
def detach_and_clone(obj):
    """Return a copy of ``obj`` in which every tensor is detached and cloned.

    Tensors are detached from the autograd graph and cloned into new storage.
    Dicts (keys kept), lists and tuples (type kept, including namedtuples)
    are rebuilt recursively. Python ``int``/``float`` values are returned
    as-is.

    Args:
        obj: A tensor, number, or (possibly nested) dict/list/tuple of them.

    Returns:
        A structure mirroring ``obj`` that shares no tensor storage with it
        and carries no autograd history.

    Raises:
        TypeError: If ``obj`` (or something nested in it) is of any other
            type. The message names the offending type.
    """
    if torch.is_tensor(obj):
        return obj.detach().clone()
    if isinstance(obj, dict):
        return {k: detach_and_clone(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [detach_and_clone(v) for v in obj]
    if isinstance(obj, tuple):
        cloned = [detach_and_clone(v) for v in obj]
        # namedtuples take their fields positionally; plain tuples an iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*cloned)
        return type(obj)(cloned)
    if isinstance(obj, (int, float)):
        return obj
    raise TypeError(f"Invalid type for detach_and_clone: {type(obj).__name__}")
def collate_list(vec):
    """Merge a list of results (e.g. one per batch) into a single result.

    The type of ``vec[0]`` decides how the list is merged:

    - Tensors are concatenated along the first dimension with ``torch.cat``.
      Each tensor therefore needs at least one dimension, and all tensors
      must have matching trailing dimensions.
    - Lists are joined into one list. Their elements are not collated
      recursively, so each element may be, e.g., its own dict.
    - Dicts produce a single dict with the keys of the first dict. Each
      key's values across all dicts are collated recursively, so every dict
      must contain those keys.

    Args:
        vec: Non-empty list of tensors, lists, or dicts.

    Returns:
        The collated tensor, list, or dict.

    Raises:
        TypeError: If ``vec`` is not a list, or its elements are not tensors,
            lists, or dicts.
        ValueError: If ``vec`` is empty.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    if not vec:
        # With no first element there is no way to know what type to return;
        # fail with a clear message instead of an IndexError from vec[0].
        raise ValueError("collate_list cannot collate an empty list")
    first = vec[0]
    if torch.is_tensor(first):
        return torch.cat(vec)
    if isinstance(first, list):
        return [item for sublist in vec for item in sublist]
    if isinstance(first, dict):
        return {k: collate_list([d[k] for d in vec]) for k in first}
    raise TypeError("Elements of the list to collate must be tensors, lists, or dicts.")