-
Notifications
You must be signed in to change notification settings - Fork 15
/
utils.py
70 lines (56 loc) · 2.14 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
import os
from datetime import datetime
from time import time
import git
import torch
from consts import NULL_ID_FOR_COREF
def flatten_list_of_lists(lst):
    """Concatenate the inner iterables of *lst* into a single flat list.

    Only one level of nesting is removed: each element of every inner
    iterable is copied, in order, into the result.

    Args:
        lst: An iterable of iterables.

    Returns:
        A new list holding all inner elements in iteration order.
    """
    flat = []
    for inner in lst:
        flat.extend(inner)
    return flat
def extract_clusters(gold_clusters):
    """Convert a padded cluster container into a list of hashable clusters.

    ``gold_clusters`` must support ``.tolist()``; the resulting nested list
    is treated as clusters of mentions. Any mention that contains
    ``NULL_ID_FOR_COREF`` is dropped, and clusters left with no mentions
    are discarded.

    NOTE(review): presumably a torch tensor of shape
    (num_clusters, max_mentions, 2) padded with NULL_ID_FOR_COREF — confirm
    against the caller.

    Returns:
        A list of clusters, each a tuple of mention tuples.
    """
    clusters = []
    for raw_cluster in gold_clusters.tolist():
        # Padding entries carry the null id; strip them out of the cluster.
        kept = tuple(
            tuple(mention)
            for mention in raw_cluster
            if NULL_ID_FOR_COREF not in mention
        )
        if kept:
            clusters.append(kept)
    return clusters
def extract_mentions_to_predicted_clusters_from_clusters(gold_clusters):
    """Build a lookup from each mention to the cluster that contains it.

    Every mention is converted to a tuple so that it can be used as a
    dictionary key. If the same mention appears in several clusters, the
    last cluster encountered wins.

    Args:
        gold_clusters: An iterable of clusters, each an iterable of mentions.

    Returns:
        A dict mapping ``tuple(mention)`` to the (unmodified) cluster object.
    """
    return {
        tuple(mention): cluster
        for cluster in gold_clusters
        for mention in cluster
    }
def extract_clusters_for_decode(mention_to_antecedent):
    """Group (mention, antecedent) links into coreference clusters.

    The links are processed in sorted order. A mention whose antecedent is
    already in a cluster joins that cluster; otherwise a new cluster
    ``[antecedent, mention]`` is started.

    NOTE(review): this assumes each antecedent sorts before the mentions
    that point at it and that each mention has a single antecedent. If a
    mention already belongs to a cluster when it starts a new one, it is
    remapped to the new cluster but stays in the old cluster's list.

    Args:
        mention_to_antecedent: An iterable of ``(mention, antecedent)`` pairs
            of hashable, mutually comparable mentions.

    Returns:
        A tuple ``(clusters, mention_to_cluster)``. ``clusters`` is a list of
        mention tuples. ``mention_to_cluster`` maps each mention to its
        index in ``clusters``.
    """
    clusters = []
    mention_to_cluster = {}
    for mention, antecedent in sorted(mention_to_antecedent):
        cluster_idx = mention_to_cluster.get(antecedent)
        if cluster_idx is None:
            # Antecedent not yet clustered: open a new cluster for the pair.
            cluster_idx = len(clusters)
            clusters.append([antecedent, mention])
            mention_to_cluster[mention] = cluster_idx
            mention_to_cluster[antecedent] = cluster_idx
        else:
            clusters[cluster_idx].append(mention)
            mention_to_cluster[mention] = cluster_idx
    return [tuple(members) for members in clusters], mention_to_cluster
def mask_tensor(t, mask):
    """Push masked-out positions of *t* down by a large negative constant.

    ``-10000.0 * (1 - mask)`` is added to ``t``, so positions where ``mask``
    is 1 are unchanged and positions where it is 0 drop by 10000. The
    result is then clamped to ``[-10000.0, 10000.0]``.

    Args:
        t: Score tensor.
        mask: Tensor broadcastable to ``t``; it is cast with ``.float()``.

    Returns:
        A new tensor holding the masked, clamped scores.
    """
    penalty = (1.0 - mask.float()) * -10000.0
    return (t + penalty).clamp(min=-10000.0, max=10000.0)
def write_meta_data(output_dir, args):
    """Write run metadata to ``<output_dir>/meta.json``.

    The JSON object (pretty-printed, keys sorted, trailing newline) records:

    * ``git_hexsha``: the HEAD commit of the git repository found from the
      current working directory, searching parent directories. It is
      ``null`` when no repository or commit can be found.
    * ``args``: every attribute in ``args.__dict__``, with values converted
      to ``str``.
    * ``date``: local wall-clock time at which the file was written,
      formatted as ``%Y-%m-%d %H:%M:%S``.

    Args:
        output_dir: Directory to write ``meta.json`` into. It must already
            exist; it is not created here.
        args: Object carrying the run configuration in ``__dict__``
            (e.g. an ``argparse.Namespace``).
    """
    output_path = os.path.join(output_dir, "meta.json")
    try:
        hexsha = git.Repo(search_parent_directories=True).head.commit.hexsha
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError, ValueError) as err:
        # Not running from a git checkout, or the repo has no commits yet
        # (``head.commit`` raises ValueError). Metadata is auxiliary, so
        # record a null hash instead of aborting the whole run.
        print(f"Could not determine git commit for {output_path}: {err}")
        hexsha = None
    ts = time()
    print(f"Writing {output_path}")
    with open(output_path, mode='w', encoding='utf-8') as f:
        json.dump(
            {
                'git_hexsha': hexsha,
                'args': {k: str(v) for k, v in args.__dict__.items()},
                'date': datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
            },
            f,
            indent=4,
            sort_keys=True)
        # json.dump does not emit a trailing newline; add one.
        print(file=f)