-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_reasoning_routine.py
152 lines (124 loc) · 6.37 KB
/
main_reasoning_routine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import sys
import json
from pycoral.utils.dataset import read_label_file
from params import get_parser
from postgresql.io import *
from postgresql.basic_queries import retrieve_new_anchor_measurements, mark_complete
from postgresql.spatial_queries import populate_with_boxes
from postgresql.size_queries import populate_with_sizes
from spatial_reasoner.spatial_reasoning import build_QSR_graph,spatial_validate
from KB.size_kb import extract_size_kb
from KB.spatial_kb import extract_spatial_kb
from KB.commonsense_rels import extract_csk
from KB.wordnet_linking import map_to_synsets
from DL.ranking_aggregation import merge_scored_ranks
from size_reasoner.size_reasoning import size_validate
from HS.graph_completion import complete_graph
from HS.HSrules import check_rules
from utils import plot_graph
def main():
    """Run one pass of the neuro-symbolic reasoning routine over new object anchors.

    Steps:
      1. Parse CLI parameters and load the class-id -> class-name label map.
      2. Load (or extract) the background KBs: WordNet synset mapping, size KB,
         spatial KB, Quasimodo commonsense facts, and the H&S rule set.
      3. Retrieve newly observed anchors from the spatial DB; if none, stop.
      4. Aggregate per-anchor DL prediction rankings; anchors whose top score
         falls below ``args_dict.dlconf`` are re-ranked via size and spatial
         reasoning, and the QSR graph labels are updated accordingly.
      5. Complete the QSR graph with commonsense relations, check H&S rules,
         and mark the processed anchors complete in the DB.

    The DB connection is always closed, even on early return or error
    (the original code leaked it when no new anchors were found).
    """
    # Load parameters and maintain as dictionary
    parser = get_parser()
    args_dict, unknown = parser.parse_known_args()
    target_classes = read_label_file(args_dict.classes)  # dictionary with format classid: classname
    connection, cursor = connect_db(args_dict.dbuser, args_dict.dbname)  # connect to db
    try:
        print("class to WN synset mapping")
        # dictionary of format classname: array of synsets
        if args_dict.extract_synsets:
            target_synsets = map_to_synsets(list(target_classes.values()), args_dict)
        else:
            with open(args_dict.syn_path) as fp:  # load pre-extracted from local
                target_synsets = json.load(fp)
            # BUGFIX: this message previously said "Loaded size KB from local"
            print("Loaded synset mapping from local")
        print("Retrieving background KBs")
        if args_dict.extract_sizekb:  # extract from scratch
            print("Extracting size knowledge from scratch ... it will take long")
            sizeKB = extract_size_kb(target_classes, args_dict)
        else:
            with open(args_dict.sizekb_path) as fp:  # load pre-extracted from local
                sizeKB = json.load(fp)
            print("Loaded size KB from local")
        if args_dict.extract_spatialkb:  # extract from scratch
            print("Extracting spatial knowledge from scratch ... it will take long")
            spatialKB = extract_spatial_kb(connection, cursor, args_dict)
        else:
            with open(args_dict.spatialkb_path) as fp:  # load pre-extracted from local
                spatialKB = json.load(fp)
            print("Loaded spatial KB from local")
        if args_dict.extract_quasi:  # extract from scratch
            print("Extracting commonsense facts from Quasimodo... it will take long")
            quasiKB = extract_csk(args_dict)
        else:
            with open(args_dict.quasikb_path) as fp:  # load pre-extracted from local
                quasiKB = json.load(fp)
            print("Loaded relevant quasimodo facts from local")
        # Loading H&S rules
        with open(args_dict.rules_src) as fp:  # load pre-extracted from local
            rule_dict = json.load(fp)
        print("Loaded set of H&S rules from local")
        # retrieve new object anchors to be examined and all DL predictions related to each anchor
        # i.e., either a newly added anchor or a former anchor for which a new measurement was recorded
        anchor_dict = retrieve_new_anchor_measurements(cursor)
        # If no new anchors, no reasoning
        if not anchor_dict:  # empty dict -> nothing to do
            print("No new observations or anchors found... stopping reasoning routine")
            return  # finally block still closes the DB connection (previously leaked here)
        populate_with_boxes(connection, cursor, sf=args_dict.sf)  # compute size and spatial bboxes of union chull
        populate_with_sizes(connection, cursor)  # estimate anchor sizes (based on bbox)
        print("Spatial DB completed with anchor bounding boxes and sizes")
        print("Extracting observed QSR between objects")
        qsr_graph = build_QSR_graph(connection, cursor, anchor_dict, args_dict)  # extract QSR
        print("Aggregating DL predictions across observations")
        aggr_ranks = []
        node_mapping = {}
        for a_id, attr in anchor_dict.items():
            # Aggregate DL rankings on same anchor
            aggr_DL_rank = merge_scored_ranks(attr['DL_predictions'])
            # annotate qsr graph with top1 DL pred
            print(aggr_DL_rank[0][0])
            topclass = target_classes[aggr_DL_rank[0][0]]
            qsr_graph.nodes[a_id]["obj_label"] = topclass
            # Flag anchors whose aggregated top confidence is below threshold:
            # only those go through the correction (reasoning) pipeline below.
            topscore = aggr_DL_rank[0][1]
            tbcorr = topscore < args_dict.dlconf
            aggr_ranks.append((a_id, aggr_DL_rank, tbcorr))
        for a_id, aggr_DL_rank, corr_flag in aggr_ranks:  # only apply reasoning to predictions needing correction
            if corr_flag:
                topclassDL = target_classes[aggr_DL_rank[0][0]]
                print("Reasoning on object anchors")
                # Validate merged ranking based on size KB
                sizev_rank = size_validate(aggr_DL_rank, cursor, a_id, sizeKB)
                # Validate ranking based on spatial KB
                spatial_outrank = spatial_validate(a_id, aggr_DL_rank, sizev_rank, qsr_graph, spatialKB, target_synsets, meta=args_dict.meta)
                read_input = [(target_classes[cid], score) for cid, score in spatial_outrank]
                print("Ranking post meta-reasoning")
                # if top class has changed also change label in scene graph
                topclassmod = read_input[0][0]
                if topclassmod != topclassDL:
                    qsr_graph.nodes[a_id]["obj_label"] = topclassmod
            else:
                print("DL confident enough, keeping original ranking")
                read_input = [(target_classes[cid], score) for cid, score in aggr_DL_rank]
            # keep track of which anchors belong to a certain class, used later for rule checking
            final_pred = qsr_graph.nodes[a_id]["obj_label"]
            node_mapping.setdefault(final_pred, []).append(a_id)
        # Scene assessment part:
        # expand QSR graph based on Quasimodo concepts, then check H&S rules
        plot_graph(qsr_graph)
        mod_graph = complete_graph(qsr_graph, quasiKB)
        plot_graph(mod_graph)
        check_rules(mod_graph, rule_dict, node_mapping)
        # once done with reasoning, mark all object anchors as complete
        mark_complete(connection, cursor, list(anchor_dict.keys()))
    finally:
        disconnect_DB(connection, cursor)  # always close database connection, even on early return/error
if __name__ == "__main__":
    # Script entry point: run the reasoning routine and terminate.
    # main() returns None, so SystemExit(None) yields exit status 0 —
    # equivalent to the explicit sys.exit(0) form.
    raise SystemExit(main())