-
Notifications
You must be signed in to change notification settings - Fork 11
/
loop_bounds.py
929 lines (804 loc) · 31.3 KB
/
loop_bounds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#
import check,search,problem,syntax,solver,logic,rep_graph,re
from rep_graph import vc_num, vc_offs, pc_true_hyp, Hyp, eq_hyp
from target_objects import functions, trace
from check import restr_others,split_visit_one_visit
from problem import inline_at_point
from syntax import mk_not, true_term, false_term, mk_implies, Expr, Type, unspecified_precond_term,mk_and
from rep_graph import mk_graph_slice, VisitCount, to_smt_expr
from search import eval_model_expr
import target_objects
import trace_refute
import stack_logic
import time
#tryFun must take exactly 1 argument
def downBinSearch(minimum, maximum, tryFun):
    """Binary-search [minimum, maximum] for the smallest value tryFun accepts.

    tryFun is called with a single integer and should return a true value
    once the candidate is large enough. The search assumes that answer is
    monotone in the candidate. `maximum` itself is never passed to tryFun:
    it is returned when every smaller candidate has been rejected.

    Returns the integer at which the lower and upper bounds meet.
    """
    upperBound = maximum
    lowerBound = minimum
    while upperBound > lowerBound:
        print('searching in %d - %d' % (lowerBound, upperBound))
        # floor division keeps the midpoint an integer even when '/' means
        # true division; a float midpoint stops the bounds from meeting.
        cur = (lowerBound + upperBound) // 2
        if tryFun(cur):
            upperBound = cur
        else:
            lowerBound = cur + 1
    assert upperBound == lowerBound
    return lowerBound
def upDownBinSearch (minimum, maximum, tryFun):
    """performs a binary search between minimum and maximum, but does not start
    in the middle. instead it does a binary escalation up from the minimum
    first. this makes sense for ranges e.g. 2 - 1000000 where the bound is
    likely to be near the bottom of the range. it also avoids testing values
    more than twice as high as the bound, which may avoid some issues.

    tryFun takes one integer argument. Returns the value found by
    downBinSearch once some probe is accepted, or None if even `maximum`
    is rejected."""
    # doubling alone never advances from 0 (or from a negative value), so
    # always step up by at least one. for minimum >= 1 this is exactly the
    # original 2 * minimum, 4 * minimum, ... sequence.
    upperBound = max (2 * minimum, minimum + 1)
    while upperBound < maximum:
        if tryFun (upperBound):
            return downBinSearch (minimum, upperBound, tryFun)
        upperBound = max (2 * upperBound, upperBound + 1)
    if tryFun (maximum):
        return downBinSearch (minimum, maximum, tryFun)
    return None
def addr_of_node (preds, n):
    """Follow predecessor links back from n to the nearest address node.

    Each node visited on the way must have exactly one predecessor in
    `preds`; the list-pattern assignment raises otherwise. Returns the
    first node for which trace_refute.is_addr is true (possibly n itself).
    """
    node = n
    while True:
        if trace_refute.is_addr (node):
            return node
        [node] = preds[node]
def all_asm_functions ():
    """List the functions tagged 'ASM' that are not single instructions.

    Both the tag lookup and the instruction test are delegated to
    stack_logic.
    """
    selected = []
    for name in stack_logic.get_functions_with_tag ('ASM'):
        if stack_logic.is_instruction (name):
            continue
        selected.append (name)
    return selected
# lazily built map from a called function's name to the set of call-site
# addresses found in the ASM functions. build_call_site_set also stores a
# ('IsLoaded', None) marker key so the dict is non-empty once built.
call_site_set = {}
def build_call_site_set ():
    """Populate call_site_set from every 'Call' node of every ASM function.

    For each call node the recorded site is the address node found by
    walking back through its predecessors (addr_of_node). A marker key
    ('IsLoaded', None) is added at the end so that the table counts as
    built even when no call node was found.
    """
    for fname in all_asm_functions ():
        fun_nodes = functions[fname].nodes
        preds = logic.compute_preds (fun_nodes)
        for (n, node) in fun_nodes.iteritems ():
            if node.kind != 'Call':
                continue
            sites = call_site_set.setdefault (node.fname, set ())
            sites.add (addr_of_node (preds, n))
    call_site_set[('IsLoaded', None)] = True
def all_call_sites (f):
    """Return a list of the known call-site addresses of function f.

    Builds the call-site table on first use. Functions with no recorded
    call site give an empty list.
    """
    if len (call_site_set) == 0:
        build_call_site_set ()
    sites = call_site_set.get (f, [])
    return list (sites)
#naive binary search to find loop bounds
def findLoopBoundBS(p_n, p, restrs=None, hyps=None, try_seq=None):
    """Search for a concrete visit-count bound for loop head p_n in problem p.

    A short ascending sequence of candidates is tried first (try_seq, or a
    default that is shorter when the loop body contains a call). If one of
    them holds, a downward binary search between it and the previous
    candidate pins the bound.

    restrs defaults to leaving every other loop head at vc_options([0],[1]);
    hyps defaults to no extra hypotheses.

    Returns the bound as an integer, or None when no candidate succeeds.
    """
    if hyps is None:
        hyps = []
    if try_seq is not None:
        bound_try_seq = try_seq
    else:
        body_calls = [n for n in p.loop_body (p_n)
            if p.nodes[n].kind == 'Call']
        if body_calls:
            bound_try_seq = [0,1,20]
        else:
            bound_try_seq = [0,1,20,34]
    rep = mk_graph_slice (p, fast = True)
    p_loop_heads = [n for n in p.loop_data if p.loop_data[n][0] == 'Head']
    print('p_loop_heads: %s' % p_loop_heads)
    if restrs is None:
        #vc_options([concrete numbers], [offsets])
        restrs = tuple ([(n2, rep_graph.vc_options ([0],[1]))
            for n2 in p_loop_heads if not n2 == p_n])

    print('getting the initial bound')
    index = tryLoopBound(p_n,p,bound_try_seq,rep,restrs = restrs, hyps=hyps)
    if index == -1:
        return None
    print('got the initial bound %d' % bound_try_seq[index])

    def bound_holds (x):
        return tryLoopBound(p_n,p,[x],rep,restrs=restrs, hyps=hyps,
            bin_return=True)

    if index == 0:
        # the very first candidate held; nothing smaller to search
        loop_bound = bound_try_seq[0]
    else:
        #do a downward binary search to find the concrete loop bound
        loop_bound = downBinSearch(bound_try_seq[index-1],
            bound_try_seq[index], bound_holds)
    print('bound = %d' % loop_bound)
    return loop_bound
def default_n_vc_cases (p, n):
    """Enumerate the default (node, visit-restrictions) cases for node n.

    Every loop head other than n's own is restricted to vc_options([0],[1]).
    If n lies in a loop, two cases are returned, with that loop's head at
    vc_num(1) and at vc_offs(1) respectively. Otherwise a single case with
    only the general restrictions is returned.
    """
    head = p.loop_id (n)
    general = []
    for n2 in p.loop_heads ():
        if n2 != head:
            general.append ((n2, rep_graph.vc_options ([0], [1])))
    if not head:
        # n is outside every loop, so no head-specific restriction applies
        return [(n, tuple (general))]
    numbered = tuple (general + [(head, rep_graph.vc_num (1))])
    offset = tuple (general + [(head, rep_graph.vc_offs (1))])
    return [(n, numbered), (n, offset)]
def callNodes(p, fs= None):
    """Return the ids of the 'Call' nodes in p.nodes.

    When fs is given, only calls whose fname is contained in fs are kept.
    """
    found = []
    for n in p.nodes:
        node = p.nodes[n]
        if node.kind != 'Call':
            continue
        if fs is not None and node.fname not in fs:
            continue
        found.append(n)
    return found
def noHaltHyps(split,p):
    """Build hypotheses that no call to 'halt' in p is reached.

    One pc_false_hyp is produced for each default visit case of each
    'halt' call node. The `split` argument is not used.
    """
    hyps = []
    for halt_n in callNodes(p, fs=['halt']):
        for n_vc in default_n_vc_cases (p, halt_n):
            tag = p.node_tags[halt_n][0]
            hyps.append (rep_graph.pc_false_hyp ((n_vc, tag)))
    return hyps
def tryLoopBound(p_n, p, bounds,rep,restrs =None, hints =None,kind = 'Number',bin_return = False,hyps = None):
    """Try candidate visit counts for loop head p_n, in the order given.

    For each candidate i the solver is asked whether, under the supplied
    hypotheses plus the no-halt hypotheses, "the error path condition is
    false implies p_n is not visited at VisitCount (kind, i)".

    bounds: candidate counts to test.
    rep: graph slice used for the path conditions and the solver query.
    restrs: visit restrictions for the other loops (default: none).
    hints: accepted for interface compatibility; not used.
    bin_return: selects the result convention, see below.
    hyps: extra hypotheses (default: none).

    Returns, when bin_return is false, the index into `bounds` of the first
    candidate that holds, or -1 if none does or get_pc fails. When
    bin_return is true the result is True / False instead.
    """
    if restrs is None:
        restrs = ()
    if hyps is None:
        hyps = []
    tag = p.node_tags[p_n][0]
    print('trying bound: %s' % bounds)
    # the no-halt hypotheses do not depend on the candidate; build them once
    # (on first need) rather than appending a fresh copy per iteration.
    all_hyps = None
    for (index, i) in enumerate(bounds):
        print('testing %d' % i)
        restrs2 = restrs + ((p_n, VisitCount (kind, i)), )
        try:
            pc = rep.get_pc ((p_n, restrs2))
        except Exception:
            # treat any failure to build the path condition as "no bound",
            # but let KeyboardInterrupt / SystemExit propagate.
            print('get_pc failed')
            if bin_return:
                return False
            return -1
        restrs3 = restr_others (p, restrs2, 2)
        epc = rep.get_pc (('Err', restrs3), tag = tag)
        hyp = mk_implies (mk_not (epc), mk_not (pc))
        if all_hyps is None:
            all_hyps = hyps + noHaltHyps(p_n,p)
        if rep.test_hyp_whyps (hyp, all_hyps):
            print('p_n %d: split limit found: %d' % (p_n, i))
            if bin_return:
                return True
            return index
    if bin_return:
        return False
    print('loop bound not found!')
    return -1
def get_linear_series_eqs (p, split, restrs, hyps, omit_standard = False):
    """Find candidate equalities that pass the loop induction checks at split.

    Candidates are search.mk_seq_eqs (p, split, 1) plus
    candidate_additional_eqs. They are accepted greedily: each accepted
    equality is assumed when checking later ones, and candidates that
    failed earlier are retried after every acceptance. The accepted list
    is cached in p.cached_analysis.

    Returns the accepted list, or, when omit_standard is true, a set of
    the accepted equalities minus the standard mk_seq_eqs ones.
    """
    k = ('linear_series_eqs', split, restrs, tuple (hyps))

    def finish (found):
        if not omit_standard:
            return found
        standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False))
        return set (found) - standard

    if k in p.cached_analysis:
        return finish (p.cached_analysis[k])

    cands = search.mk_seq_eqs (p, split, 1, with_rodata = False)
    cands += candidate_additional_eqs (p, split)
    (tag, _) = p.node_tags[split]

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def do_checks (eqs_assume, eqs):
        step_checks = check.single_loop_induct_step_checks (p, restrs, hyps,
            tag, split, 1, eqs, eqs_assume = eqs_assume)
        base_checks = check.single_loop_induct_base_checks (p, restrs, hyps,
            tag, split, 1, eqs)
        for group in check.proof_check_groups (step_checks + base_checks):
            (res, _) = check.test_hyp_group (rep, group)
            if not res:
                return False
        return True

    accepted = []
    rejected = []
    while cands:
        cand = cands.pop ()
        if not do_checks (accepted, [cand]):
            rejected.append (cand)
            continue
        accepted.append (cand)
        # a new assumption may make earlier rejects provable: requeue them
        rejected.reverse ()
        cands = rejected + cands
        rejected = []

    assert do_checks ([], accepted)
    p.cached_analysis[k] = accepted
    return finish (accepted)
def get_linear_series_hyps (p, split, restrs, hyps):
    """Return the linear-series equalities at split, expressed as hypotheses.

    The equalities come from get_linear_series_eqs and are instantiated at
    visit vc_offs (0). Only the hypothesis of each (hypothesis, description)
    pair is kept.
    """
    eqs = get_linear_series_eqs (p, split, restrs, hyps)
    (tag, _) = p.node_tags[split]
    # NOTE(review): linear_eq_hyps_at_visit is neither defined nor imported
    # in this module, so the bare name raised NameError. It is presumed to
    # live in `check` next to split_visit_one_visit -- confirm.
    visit_hyps = check.linear_eq_hyps_at_visit (tag, split, eqs,
        restrs, vc_offs (0))
    return [h for (h, _) in visit_hyps]
def is_zero (expr):
    """True if expr is a 'Num' whose value is zero modulo 2 ** expr.typ.num."""
    if expr.kind != 'Num':
        return False
    width_mask = (1 << expr.typ.num) - 1
    return (expr.val & width_mask) == 0
def candidate_additional_eqs (p, split):
    """Propose extra candidate equalities for the loop at split.

    Word-typed 'Equals' comparisons in the loop body are collected as
    ordered pairs. Comparisons of zero against a 'Plus' or 'Minus' term add
    pairs built from that term's two operands. For each pair whose first
    element is a 'LoopLinearSeries' variable (per
    search.get_loop_var_analysis_at), four equalities are proposed relating
    a comparison of that element against its rewritten form to a
    comparison of the pair. Assertions supplied by the
    'extra_wcet_assertions' hooks are appended.
    """
    pairs = set ()

    def collect_word_equalities (expr):
        if not expr.is_op ('Equals'):
            return
        if expr.vals[0].typ.kind != 'Word':
            return
        [lhs, rhs] = expr.vals
        pairs.update ([(lhs, rhs), (rhs, lhs)])

    for n in p.loop_body (split):
        p.nodes[n].visit (lambda x: (), collect_word_equalities)

    # iterate a snapshot: the set is extended while being scanned
    for (zero, term) in list (pairs):
        if not is_zero (zero):
            continue
        if term.is_op ('Plus'):
            [a, b] = term.vals
            pairs.add ((a, syntax.mk_uminus (b)))
            pairs.add ((b, syntax.mk_uminus (a)))
        elif term.is_op ('Minus'):
            [a, b] = term.vals
            pairs.add ((a, b))
            pairs.add ((b, a))

    loop = syntax.mk_var ('%i', syntax.word32T)
    minus_loop_step = syntax.mk_uminus (loop)

    vas = search.get_loop_var_analysis_at(p, split)
    ls_vas = dict ([(var, [data]) for (var, data) in vas
        if data[0] == 'LoopLinearSeries'])
    cmp_series = [(x, y, rew, offs) for (x, y) in pairs
        for (_, rew, offs) in ls_vas.get (x, [])]

    odd_eqs = []
    for (x, y, rew, offs) in cmp_series:
        init_cmps = [syntax.mk_less_eq (x, rew (x, minus_loop_step)),
            syntax.mk_less_eq (rew (x, minus_loop_step), x)]
        fin_cmps = [syntax.mk_less (x, y), syntax.mk_less (y, x)]
        for fin_cmp in fin_cmps:
            for init_cmp in init_cmps:
                odd_eqs.append (syntax.mk_eq (init_cmp, fin_cmp))

    ass_eqs = []
    var_deps = p.compute_var_dependencies ()
    for hook in target_objects.hooks ('extra_wcet_assertions'):
        ass_eqs.extend (hook (var_deps[split]))

    return odd_eqs + ass_eqs
# NOTE(review): not referenced anywhere else in this file; presumably read
# by another module -- confirm before removing.
extra_loop_consts = [2 ** 16]
# recently built problems, as ((call_ctxt, f), p, hyps, addr_map) entries.
# get_call_ctxt_problem appends here and trims the list to the last 20.
call_ctxt_problems = []
# one-element mutable flag. when avoid_C_information[0] is true, hypotheses
# that has_C_information flags are dropped and C-derived bounds are not tried.
avoid_C_information = [False]
def get_call_ctxt_problem (split, call_ctxt, timing = True):
    """Return (p, hyps, addr_map) for the function containing split, as
    reached through call_ctxt.

    Results are kept in call_ctxt_problems (last 20 entries) keyed by the
    calling context and the identified function. On a cache miss the
    compound problem is built, optionally stripped of hypotheses carrying C
    information, and, if `timing` is true, the build time is logged.
    """
    # time this for diagnostic reasons
    start = time.time ()
    from trace_refute import identify_function, build_compound_problem_with_links
    f = identify_function (call_ctxt, [split])
    key = (call_ctxt, f)
    for (ctxt2, p, hyps, addr_map) in call_ctxt_problems:
        if ctxt2 == key:
            return (p, hyps, addr_map)

    (p, hyps, addr_map) = build_compound_problem_with_links (call_ctxt, f)
    if avoid_C_information[0]:
        hyps = [h for h in hyps if not has_C_information (p, h)]
    call_ctxt_problems.append ((key, p, hyps, addr_map))
    # keep only the 20 most recent entries
    del call_ctxt_problems[: -20]

    elapsed = time.time () - start
    if timing:
        save_extra_timing ('GetProblem', call_ctxt + [split], elapsed)
    return (p, hyps, addr_map)
def has_C_information (p, hyp):
    """Return True if hyp visits any tag whose hook_tag_hints entry is not
    'ASM'. Falls through (returning None) otherwise."""
    for visit in hyp.visits ():
        (n_vc, tag) = visit
        if p.hook_tag_hints.get (tag, None) != 'ASM':
            return True
# memo for get_bound_ctxt: (p.name, split, bound, restrs, hyps-tuple) ->
# the (restrs, hyps) produced by add_loop_bound_restrs_hyps.
known_bound_restr_hyps = {}
# bounds loaded from / recorded for LoopBounds.txt. keys are a loop address
# (per-context bounds), an (address, 'Global') pair (global bounds), or the
# marker 'Loaded' set by load_bounds.
known_bounds = {}
def serialise_bound (addr, bound_info):
    """Render a loop address and its bound as three strings.

    bound_info is None or a (bound, kind) pair. The result is
    [hex address, bound, kind], with "None" in both bound fields when there
    is no bound.
    """
    if bound_info is None:
        return [hex(addr), "None", "None"]
    (bound, kind) = bound_info
    assert logic.is_int (bound)
    assert str (kind) == kind
    return [hex (addr), str (bound), kind]
def save_bound (glob, split_bin_addr, call_ctxt, prob_hash, prev_bounds, bound,
        time = None):
    """Append a bound record to LoopBounds.txt in the target directory.

    glob: true for a 'GlobalLoopBound' record, false for 'LoopBound'.
    split_bin_addr: address identifying the loop.
    call_ctxt: list of call-site addresses.
    prob_hash: hash stored with the record.
    prev_bounds: list of (address, bound) pairs for prior loops; must be
        None when glob is true.
    bound: None or a (bound, kind) pair.
    time: if given, a 'LoopBoundTiming' line is written as well.

    A comment line naming the loop precedes the record, and the result is
    also reported through trace.
    """
    f_names = [trace_refute.get_body_addrs_fun (x)
        for x in call_ctxt + [split_bin_addr]]
    loop_name = '<%s>' % ' -> '.join (f_names)
    comment = '# bound for loop in %s:' % loop_name
    ss = ['LoopBound'] + serialise_bound (split_bin_addr, bound)
    if glob:
        ss[0] = 'GlobalLoopBound'
    ss += [str (len (call_ctxt))] + [hex (x) for x in call_ctxt]
    ss += [str (prob_hash)]
    if glob:
        assert prev_bounds == None
    else:
        ss += [str (len (prev_bounds))]
        # use distinct loop names: reusing `bound` here would clobber the
        # parameter that the final trace message reports.
        for (prev_split, prev_bound) in prev_bounds:
            ss += serialise_bound (prev_split, prev_bound)
    s = ' '.join (ss)
    with open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a') as f:
        f.write (comment + '\n')
        f.write (s + '\n')
        if time != None:
            ctxt2 = call_ctxt + [split_bin_addr]
            ctxt2 = ' '.join ([str (len (ctxt2))] + [hex (x) for x in ctxt2])
            f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time))
    trace ('Found bound %s for 0x%x in %s.' % (bound, split_bin_addr,
        loop_name))
def save_extra_timing (nm, ctxt, time):
    """Append an 'ExtraTiming' line to LoopBounds.txt in the target directory.

    The line holds the timing name nm, the length of ctxt, each ctxt entry
    in hex, and the time.
    """
    ss = ['ExtraTiming', nm, str (len (ctxt))]
    ss += [hex (x) for x in ctxt]
    ss.append (str (time))
    # `with` closes the file even if the write fails
    with open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a') as f:
        f.write (' '.join (ss) + '\n')
def parse_bound (ss, n):
    """Parse the three fields written by serialise_bound, starting at ss[n].

    Returns (n + 3, (addr, info)) where info is None if the bound field is
    'None', and (bound, kind) otherwise.
    """
    addr = syntax.parse_int (ss[n])
    bound_field = ss[n + 1]
    if bound_field == 'None':
        info = None
    else:
        info = (syntax.parse_int (bound_field), ss[n + 2])
    return (n + 3, (addr, info))
def parse_ctxt_id (bits, n):
    """Parse one integer context entry at bits[n]; returns (n + 1, value)."""
    value = syntax.parse_int (bits[n])
    return (n + 1, value)
def parse_ctxt (bits, n):
    """Parse a list of context entries starting at bits[n], using
    syntax.parse_list with parse_ctxt_id as the element parser."""
    parsed = syntax.parse_list (parse_ctxt_id, bits, n)
    return parsed
def load_bounds ():
    """Load LoopBounds.txt from the target directory into known_bounds.

    A file that cannot be opened or read is treated as empty. Only lines
    starting with 'LoopBound' or 'GlobalLoopBound' are parsed; all other
    lines (comments, timing records, blanks) are skipped. 'LoopBound'
    records are stored under the loop address as
    (ctxt, prob_hash, prev_bounds, bound). 'GlobalLoopBound' records are
    stored under (address, 'Global') as (ctxt, prob_hash, bound). Finally
    the marker key 'Loaded' is set, so known_bounds is non-empty afterwards.
    """
    try:
        # `with` closes the file even if reading it fails part-way
        with open ('%s/LoopBounds.txt' % target_objects.target_dir) as f:
            ls = list (f)
    except IOError:
        ls = []
    from syntax import parse_int, parse_list
    for l in ls:
        bits = l.split ()
        if bits[:1] not in [['LoopBound'], ['GlobalLoopBound']]:
            continue
        (n, (addr, bound)) = parse_bound (bits, 1)
        (n, ctxt) = parse_ctxt (bits, n)
        prob_hash = parse_int (bits[n])
        n += 1
        if bits[0] == 'LoopBound':
            (n, prev_bounds) = parse_list (parse_bound, bits, n)
            assert n == len (bits), bits
            known = known_bounds.setdefault (addr, [])
            known.append ((ctxt, prob_hash, prev_bounds, bound))
        else:
            assert n == len (bits), bits
            known = known_bounds.setdefault ((addr, 'Global'), [])
            known.append ((ctxt, prob_hash, bound))
    known_bounds['Loaded'] = True
def get_bound_ctxt (split, call_ctxt, use_cache = True):
    """Find a bound for the loop containing address split in one calling
    context.

    Bounds for loops that can reach this one are computed first
    (recursively) and turned into restrictions and hypotheses. A recorded
    bound is reused when use_cache is true and its context, problem hash
    and prior bounds all match. Otherwise a search is run and its result is
    recorded in known_bounds and in LoopBounds.txt.

    Returns whatever search_bin_bound returns: a (bound, kind) pair or None.
    """
    trace ('Getting bound for 0x%x in context %s.' % (split, call_ctxt))
    (p, hyps, addr_map) = get_call_ctxt_problem (split, call_ctxt)

    def min_bin_addr (loop_head):
        # smallest binary address mapped into the given loop
        return min ([addr for addr in addr_map
            if p.loop_id (addr_map[addr]) == loop_head])

    orig_split = split
    split = p.loop_id (addr_map[split])
    assert split, (orig_split, call_ctxt)
    split_bin_addr = min_bin_addr (split)

    restrs = ()
    prev_bounds = []
    for prior_head in get_prior_loop_heads (p, split):
        # recursion!
        split2 = p.loop_id (prior_head)
        assert split2
        addr = min_bin_addr (split2)
        bound = get_bound_ctxt (addr, call_ctxt)
        prev_bounds.append ((addr, bound))
        k = (p.name, split2, bound, restrs, tuple (hyps))
        if k not in known_bound_restr_hyps:
            known_bound_restr_hyps[k] = add_loop_bound_restrs_hyps (p,
                restrs, hyps, split2, bound, call_ctxt + [orig_split])
        (restrs, hyps) = known_bound_restr_hyps[k]

    # start timing now. we miss some setup time, but it avoids double counting
    # the recursive searches.
    start = time.time ()

    p_h = problem_hash (p)
    prev_bounds = sorted (prev_bounds)
    if not known_bounds:
        load_bounds ()
    for (call_ctxt2, h, prev_bounds2, bound) in known_bounds.get (
            split_bin_addr, []):
        # a recorded context matches if it is empty or a suffix of ours
        match = (not call_ctxt2
            or call_ctxt[- len (call_ctxt2):] == call_ctxt2)
        if match and use_cache and h == p_h and prev_bounds2 == prev_bounds:
            return bound

    bound = search_bin_bound (p, restrs, hyps, split)
    known = known_bounds.setdefault (split_bin_addr, [])
    known.append ((call_ctxt, p_h, prev_bounds, bound))
    end = time.time ()
    save_bound (False, split_bin_addr, call_ctxt, p_h, prev_bounds, bound,
        time = end - start)
    return bound
def problem_hash (p):
    """Hash a problem from its name, entries, sorted outputs and sorted
    nodes, via syntax.hash_tuplify."""
    outputs = sorted (p.outputs.iteritems ())
    nodes = sorted (p.nodes.iteritems ())
    return syntax.hash_tuplify ([p.name, p.entries, outputs, nodes])
def search_bin_bound (p, restrs, hyps, split):
    """Search for a bound on the loop at split, falling back to the C side.

    search_bound is tried first. If it finds nothing, a bound is inferred
    from the matching C function through getBinaryBoundFromC, unless C
    information is being avoided, the loop has prior loop heads, there is
    not exactly one matching C tag, or the loop does not have exactly one
    entry site.

    Returns a (bound, kind) pair or None.
    """
    # format the message here; passing the tuple as a second argument left
    # the %x / %s placeholders unexpanded.
    trace ('Searching for bound for 0x%x in %s.' % (split, p.name))
    bound = search_bound (p, restrs, hyps, split)
    if bound:
        return bound

    # try to use a bound inferred from C
    if avoid_C_information[0]:
        # OK told not to
        return None
    if get_prior_loop_heads (p, split):
        # too difficult for now
        return None
    asm_tag = p.node_tags[split][0]
    (_, fname, _) = p.get_entry_details (asm_tag)
    funs = [f for pair in target_objects.pairings[fname]
        for f in pair.funs.values ()]
    c_tags = [tag for tag in p.tags ()
        if p.get_entry_details (tag)[1] in funs and tag != asm_tag]
    if len (c_tags) != 1:
        print('Surprised to see multiple matching tags %s' % c_tags)
        return None

    [c_tag] = c_tags

    rep = rep_graph.mk_graph_slice (p)
    if len (search.get_loop_entry_sites (rep, restrs, hyps, split)) != 1:
        # technical, but it's not going to work in this case
        return None

    return getBinaryBoundFromC (p, c_tag, split, restrs, hyps)
def rab_test ():
    """Manual test: derive a bound from C for the single loop of
    'resolveAddressBits'."""
    # get_loop_heads yields (address, function name, has_inner_loop) triples;
    # only the address is wanted here.
    [(split_bin_addr, _, _)] = get_loop_heads (functions['resolveAddressBits'])
    (p, hyps, addr_map) = get_call_ctxt_problem (split_bin_addr, [])
    split = p.loop_id (addr_map[split_bin_addr])
    [c_tag] = [tag for tag in p.tags () if tag != p.node_tags[split][0]]
    return getBinaryBoundFromC (p, c_tag, split, (), hyps)
# one-element holder for the (p, restrs, hyps, split) arguments of the most
# recent search_bound call; starts out as 0 before any call.
last_search_bound = [0]
def search_bound (p, restrs, hyps, split):
    """Search for a bound on the loop at split within problem p.

    Three strategies are tried in turn: a naive search over small
    candidates, an inductive search between 16 and max_acceptable_bound[0]
    using the linear-series hypotheses, and finally the naive search with
    its default candidates.

    Returns (bound, 'NaiveBinSearch'), (bound, 'InductiveBinSearch') or None.
    """
    last_search_bound[0] = (p, restrs, hyps, split)

    # try a naive bin search first
    # limit this to a small bound for time purposes
    # - for larger bounds the less naive approach can be faster
    naive = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps,
        try_seq = [0, 1, 6])
    if naive is not None:
        return (naive, 'NaiveBinSearch')

    l_hyps = get_linear_series_hyps (p, split, restrs, hyps)

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def test (n):
        assert n > 10
        hyp = check.mk_loop_counter_eq_hyp (p, split, restrs, n - 2)
        visit = ((split, vc_offs (2)), ) + restrs
        continue_to_split_guess = rep.get_pc ((split, visit))
        return rep.test_hyp_whyps (syntax.mk_not (continue_to_split_guess),
            [hyp] + l_hyps + hyps)

    # NOTE(review): the original comment said findLoopBoundBS always checks
    # to at least 16, but the call above only tries up to 6 -- confirm
    # whether 16 is still the right starting point.
    min_bound = 16
    max_bound = max_acceptable_bound[0]
    inductive = upDownBinSearch (min_bound, max_bound, test)
    if inductive is not None and test (inductive):
        return (inductive, 'InductiveBinSearch')

    # let the naive bin search go a bit further
    naive = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps)
    if naive is not None:
        return (naive, 'NaiveBinSearch')

    return None
def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps):
    """Derive a bound for the binary loop at asm_split from bounds on the
    C loops.

    The C loop heads are bounded with search_bound. A split relating the
    binary loop to a C loop is then searched for and checked. From the
    bound on that C loop and the two sequences of the split, a bound for
    the binary loop is computed.

    Returns (bound, 'FromC'), or None if any step fails, including a
    solver.SolverFailure.
    """
    c_heads = [h for h in search.init_loops_to_split (p, restrs)
        if p.node_tags[h][0] == c_tag]
    c_bounds = [(p.loop_id (head), search_bound (p, (), hyps, head))
        for head in c_heads]
    if not [b for (n, b) in c_bounds if b]:
        trace ('no C bounds found (%s).' % c_bounds)
        return None

    asm_tag = p.node_tags[asm_split][0]
    both_tags = [asm_tag, c_tag]

    rep = rep_graph.mk_graph_slice (p)
    i_seq_opts = [(0, 1), (1, 1), (2, 1)]
    j_seq_opts = [(0, 1), (0, 2), (1, 1)]
    try:
        found = search.find_split (rep, asm_split, restrs, hyps, i_seq_opts,
            j_seq_opts, 5, tags = both_tags)
    except solver.SolverFailure:
        return None
    if not found or found[0] != 'Split':
        trace ('no split found (%s).' % repr (found))
        return None
    (_, split) = found

    rep = rep_graph.mk_graph_slice (p)
    checks = check.split_checks (p, (), hyps, split, tags = [asm_tag, c_tag])
    try:
        for group in check.proof_check_groups (checks):
            (res, el) = check.test_hyp_group (rep, group)
            if not res:
                trace ('split check failed!')
                trace ('failed at %s' % el)
                return None
    except solver.SolverFailure:
        return None

    (as_details, c_details, _, n, _) = split
    (c_split, (c_start, c_step), _) = c_details
    c_bound = dict (c_bounds).get (p.loop_id (c_split))
    if not c_bound:
        trace ('key split was not bounded (%r, %r).' % (c_split, c_bounds))
        return None
    (c_bound, _) = c_bound
    # integer (floor) division, as for Python 2 '/' on ints
    max_it = (c_bound - c_start) // c_step
    assert max_it > n, (max_it, n)
    (_, (as_start, as_step), _) = as_details
    as_bound = as_start + (max_it * as_step)
    # increment by 1 as this may be a bound for a different splitting point
    # which occurs later in the loop
    as_bound += 1
    return (as_bound, 'FromC')
def get_prior_loop_heads (p, split, use_rep = None):
    """List the loop heads of p from which the loop of split is reachable.

    Each head's own prior heads are placed before it, so the result is in
    dependency order. use_rep lets the recursive calls share one graph
    slice; one is created if it is not supplied.
    """
    rep = use_rep if use_rep else rep_graph.mk_graph_slice (p)
    split = p.loop_id (split)
    prior = []
    for h in p.loop_heads ():
        seen = set (prior)
        if h in seen:
            continue
        if not rep.get_reachable (h, split):
            continue
        if h == split:
            continue
        # need to recurse to ensure prior are in order
        earlier = get_prior_loop_heads (p, h, use_rep = rep)
        prior.extend ([h2 for h2 in earlier if h2 not in seen])
        prior.append (h)
    return prior
def add_loop_bound_restrs_hyps (p, restrs, hyps, split, bound, ctxt):
    """Extend restrs and hyps with what is known about the loop at split.

    The linear-series hypotheses for the loop are merged into hyps
    (duplicates removed), and a visit restriction for split is appended to
    restrs. The time taken is logged as 'LoopBoundRestrHyps' for ctxt.

    Returns the new (restrs, hyps).
    """
    # time this for diagnostic reasons
    start = time.time ()

    merged = hyps + get_linear_series_hyps (p, split, restrs, hyps)
    hyps = list (set (merged))
    # NOTE(review): get_bound_ctxt passes the result of get_bound_ctxt here,
    # i.e. None or a (bound, kind) pair, not a bare number. Under Python 2 a
    # tuple compares as >= any int, so the vc_upto branch looks unreachable
    # from that caller -- confirm whether the pair should be unpacked.
    if bound == None or bound >= 10:
        #vc_options([concrete numbers], [offsets])
        extra = (split, rep_graph.vc_options([0],[1]))
    else:
        extra = (split, rep_graph.vc_upto (bound+1))
    restrs = restrs + (extra,)

    elapsed = time.time () - start
    save_extra_timing ('LoopBoundRestrHyps', ctxt, elapsed)
    return (restrs, hyps)
# one-element holder for the upper limit of the inductive search in
# search_bound.
max_acceptable_bound = [1000000]
# one-element cache for get_functions_hash; None until first computed.
functions_hash = [None]
def get_functions_hash ():
    """Return a hash over all entries of `functions`, computed once and
    then cached in functions_hash[0]."""
    cached = functions_hash[0]
    if cached is not None:
        return cached
    per_function = [(f, hash (functions[f])) for f in functions]
    h = hash (tuple (sorted (per_function)))
    functions_hash[0] = h
    return h
# address -> smallest address in the same loop; filled per function by
# add_fun_to_loop_data_cache.
addr_to_loop_id_cache = {}
# smallest loop address -> result of problem.has_inner_loop for that loop;
# filled by add_fun_to_loop_data_cache.
complex_loop_id_cache = {}
def addr_to_loop_id (split):
    """Map an address inside a loop to the smallest address of that loop.

    The containing function's loops are analysed on a cache miss. A KeyError
    is raised if the address is still unknown afterwards.
    """
    if split in addr_to_loop_id_cache:
        return addr_to_loop_id_cache[split]
    add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split))
    return addr_to_loop_id_cache[split]
def is_complex_loop (split):
    """Return the cached problem.has_inner_loop result for the loop that
    contains address split."""
    loop_addr = addr_to_loop_id (split)
    if loop_addr not in complex_loop_id_cache:
        fname = trace_refute.get_body_addrs_fun (loop_addr)
        add_fun_to_loop_data_cache (fname)
    return complex_loop_id_cache[loop_addr]
def get_loop_addrs (split):
    """List the address nodes of the containing function that are cached as
    belonging to the same loop as split."""
    loop_addr = addr_to_loop_id (split)
    f = functions[trace_refute.get_body_addrs_fun (loop_addr)]
    members = []
    for addr in f.nodes:
        if not trace_refute.is_addr (addr):
            continue
        if addr_to_loop_id_cache.get (addr) == loop_addr:
            members.append (addr)
    return members
def add_fun_to_loop_data_cache (fname):
    """Analyse the loops of function fname and fill the loop caches.

    For every loop head, each address node in the loop body is mapped to
    the smallest such address in addr_to_loop_id_cache, and
    complex_loop_id_cache records under that smallest address whether the
    loop has an inner loop.

    Returns the smallest address of the last loop processed, or None when
    the function has no loops.
    """
    p = functions[fname].as_problem (problem.Problem)
    p.do_loop_analysis ()
    # initialise so that a loop-free function returns None instead of
    # raising UnboundLocalError at the final return.
    min_addr = None
    for h in p.loop_heads ():
        addrs = [n for n in p.loop_body (h)
            if trace_refute.is_addr (n)]
        min_addr = min (addrs)
        for addr in addrs:
            addr_to_loop_id_cache[addr] = min_addr
        complex_loop_id_cache[min_addr] = problem.has_inner_loop (p, h)
    return min_addr
def get_bound_super_ctxt (split, call_ctxt, no_splitting=False,
        known_bound_only=False):
    """Return a bound for the loop at split that holds for every calling
    context extending call_ctxt.

    A recorded global bound is returned if its context and functions hash
    match. Requests for an address other than the loop's smallest address
    are redirected to that address. With known_bound_only set, None is
    returned instead of searching. Otherwise the bound is computed by
    get_bound_super_ctxt_inner (a problem.Abort counts as "no bound") and
    recorded, unless the search was cut short because splitting was
    disallowed.
    """
    if not known_bounds:
        load_bounds ()
    for (ctxt2, fn_hash, bound) in known_bounds.get ((split, 'Global'), []):
        if ctxt2 == call_ctxt and fn_hash == get_functions_hash ():
            return bound

    min_loop_addr = addr_to_loop_id (split)
    if min_loop_addr != split:
        return get_bound_super_ctxt (min_loop_addr, call_ctxt,
            no_splitting = no_splitting, known_bound_only = known_bound_only)

    if known_bound_only:
        return None

    # set to [True] by the inner search if it gave up for lack of splitting
    no_splitting_abort = [False]
    try:
        bound = get_bound_super_ctxt_inner (split, call_ctxt,
            no_splitting = (no_splitting, no_splitting_abort))
    except problem.Abort:
        bound = None
    if no_splitting_abort[0]:
        # don't record this bound, since it might change if splitting was allowed
        return bound

    fn_hash = get_functions_hash ()
    known_bounds.setdefault ((split, 'Global'), []).append (
        (call_ctxt, fn_hash, bound))
    save_bound (True, split, call_ctxt, get_functions_hash (), None, bound)
    return bound
from trace_refute import (function_limit, ctxt_within_function_limits)
def call_ctxt_computable (split, call_ctxt):
    """True unless some function identified from a call site in call_ctxt
    has a complex loop (per trace_refute.has_complex_loop). The offending
    functions are traced. `split` is not used."""
    non_computable = []
    for call_site in call_ctxt:
        f = trace_refute.identify_function ([], [call_site])
        if trace_refute.has_complex_loop (f):
            non_computable.append (f)
    if non_computable:
        trace ('avoiding functions with complex loops: %s' % non_computable)
    return not non_computable
def get_bound_super_ctxt_inner (split, call_ctxt,
        no_splitting = (False, None)):
    """Compute a bound for the loop at split under call_ctxt without
    consulting the global cache.

    Steps, in order: function limits on the outermost function, a unique
    safe call site (extend the context and retry), function_limit_bound,
    a direct search in this context, and finally splitting over all safe
    call sites and merging their bounds.

    no_splitting is a (flag, abort_cell) pair. When the flag is set and no
    bound is found directly, abort_cell[0] is set to True and None is
    returned.

    Returns a (bound, kind) pair or None.
    """
    outermost = (call_ctxt + [split])[:1]
    first_f = trace_refute.identify_function ([], outermost)
    call_sites = all_call_sites (first_f)

    if function_limit (first_f) == 0:
        return (0, 'FunctionLimit')
    safe_call_sites = []
    for cs in call_sites:
        if ctxt_within_function_limits ([cs] + call_ctxt):
            safe_call_sites.append (cs)
    if call_sites and not safe_call_sites:
        return (0, 'FunctionLimit')

    if len (call_ctxt) < 3 and len (safe_call_sites) == 1:
        call_ctxt2 = list (safe_call_sites) + call_ctxt
        if call_ctxt_computable (split, call_ctxt2):
            trace ('using unique calling context %s' % str ((split, call_ctxt2)))
            return get_bound_super_ctxt (split, call_ctxt2)

    fname = trace_refute.identify_function (call_ctxt, [split])
    limit_bound = function_limit_bound (fname, split)
    if limit_bound:
        return limit_bound

    direct_bound = get_bound_ctxt (split, call_ctxt)
    if direct_bound:
        return direct_bound

    trace ('no bound found immediately.')

    (splitting_disabled, abort_cell) = (no_splitting[0], no_splitting[1])
    if splitting_disabled:
        assert abort_cell, no_splitting
        abort_cell[0] = True
        trace ('cannot split by context (recursion).')
        return None

    # try to split over potential call sites
    if len (call_ctxt) >= 3:
        trace ('cannot split by context (context depth).')
        return None

    if len (call_sites) == 0:
        # either entry point or nonsense
        trace ('cannot split by context (reached top level).')
        return None

    problem_sites = [call_site for call_site in safe_call_sites
        if not call_ctxt_computable (split, [call_site] + call_ctxt)]
    if problem_sites:
        trace ('cannot split by context (issues in %s).' % problem_sites)
        return None

    anc_bounds = []
    for call_site in safe_call_sites:
        anc_bounds.append (get_bound_super_ctxt (split,
            [call_site] + call_ctxt, no_splitting = True))
    if None in anc_bounds:
        return None
    (bound, kind) = max (anc_bounds)
    return (bound, 'MergedBound')
def function_limit_bound (fname, split):
    """Bound the loop at split by the call limits of functions it calls.

    Considers the call nodes in the loop body whose callee has a function
    limit. If those nodes divide the loop (logic.divides_loop), returns
    (sum of the distinct callees' limits + 1, 'FunctionLimit'). Otherwise
    returns None.
    """
    p = functions[fname].as_problem (problem.Problem)
    p.do_analysis ()
    cuts = []
    for n in p.loop_body (split):
        if p.nodes[n].kind != 'Call':
            continue
        if function_limit (p.nodes[n].fname) != None:
            cuts.append (n)
    if not cuts:
        return None
    graph = p.mk_node_graph (p.loop_body (split))
    if not logic.divides_loop (graph, set (cuts)):
        return None
    # it is not possible to iterate the loop without visiting a bounded
    # function. naively, this sets the limit to the sum of all the possible
    # bounds, plus one because we can enter the loop a final time without
    # visiting any function call site yet.
    fnames = set ([p.nodes[n].fname for n in cuts])
    total = sum ([function_limit (f) for f in fnames])
    return (total + 1, 'FunctionLimit')
def loop_bound_difficulty_estimates (split, ctxt):
    """Give rough size measures of the bounding problem for split in ctxt.

    Returns (instructions in the loop, instructions sharing the loop's tag,
    instructions in the whole calling-context problem), counting the nodes
    that addr_map maps addresses to.
    """
    # various guesses at how hard the loop bounding problem is.
    (p, hyps, addr_map) = get_call_ctxt_problem (split, ctxt, timing = False)

    loop_id = p.loop_id (addr_map[split])
    assert loop_id

    inst_node_ids = set (addr_map.itervalues ())
    # number of instructions in the loop
    num_loop_insts = len ([n for n in p.loop_body (loop_id)
        if n in inst_node_ids])
    # number of instructions in the function
    tag = p.node_tags[loop_id][0]
    num_fun_insts = len ([n for n in inst_node_ids
        if p.node_tags[n][0] == tag])
    # number of instructions in the whole calling context
    num_ctxt_insts = len (inst_node_ids)

    # well, what else?
    return (num_loop_insts, num_fun_insts, num_ctxt_insts)
def load_timing ():
    """Read the timing records for the current target.

    From LoopBounds.txt, 'LoopBoundTiming' lines are collected per
    (split, context) and summed, and 'ExtraTiming' lines are summed. From
    the file 'time' in the target directory, the single line containing
    '(wall clock)' supplies the total time: its last field is read as
    colon-separated units of 60, least significant last.

    Returns (loop_time, ext_time, tot_time, timing), where timing maps
    (split, tuple (ctxt)) to seconds.
    """
    timing = {}
    loop_time = 0.0
    ext_time = 0.0
    # `with` closes each file even when an assert or parse fails mid-read
    with open ('%s/LoopBounds.txt' % target_objects.target_dir) as f:
        for line in f:
            bits = line.split ()
            if not (bits and 'Timing' in bits[0]):
                continue
            if bits[0] == 'LoopBoundTiming':
                (n, ext_ctxt) = parse_ctxt (bits, 1)
                assert n == len (bits) - 1
                time = float (bits[n])
                # the last context entry is the loop address itself
                ctxt = ext_ctxt[:-1]
                split = ext_ctxt[-1]
                timing[(split, tuple(ctxt))] = time
                loop_time += time
            elif bits[0] == 'ExtraTiming':
                time = float (bits[-1])
                ext_time += time
    with open ('%s/time' % target_objects.target_dir) as f:
        [l] = [l for l in f if '(wall clock)' in l]
    tot_time_str = l.split ()[-1]
    tot_time = sum ([float(s) * (60 ** i)
        for (i, s) in enumerate (reversed (tot_time_str.split(':')))])

    return (loop_time, ext_time, tot_time, timing)
def mk_timing_metrics ():
    """Pair every recorded per-context bound with its difficulty estimates.

    Only known_bounds entries whose key has type int are used, which skips
    the (addr, 'Global') and 'Loaded' keys. Returns a list of
    (split, ctxt, bound, estimates) tuples with duplicates removed.
    """
    if not known_bounds:
        load_bounds ()
    probs = []
    for (split_bin_addr, known) in known_bounds.iteritems ():
        if type (split_bin_addr) != int:
            continue
        for (call_ctxt, h, prev_bounds, bound) in known:
            probs.append ((split_bin_addr, tuple (call_ctxt), bound))
    probs = set (probs)
    data = []
    for (split, ctxt, bound) in probs:
        ests = loop_bound_difficulty_estimates (split, list (ctxt))
        data.append ((split, ctxt, bound, ests))
    return data
# sigh, this is so much work.
# numeric code written to LoopTimingMetrics.txt for each kind of bound.
# save_timing_metrics writes 7 for loops that have no bound at all.
bound_kind_nums = {
'FunctionLimit': 2,
'NaiveBinSearch': 3,
'InductiveBinSearch': 4,
'FromC': 5,
'MergedBound': 6,
}
# colour names; save_timing_metrics picks one by the index it is given.
gnuplot_colours = [
"dark-red", "dark-blue", "dark-green", "dark-grey",
"dark-orange", "dark-magenta", "dark-cyan"]
def save_timing_metrics (num):
    """Write LoopTimingMetrics.txt for the current target.

    num selects the colour from gnuplot_colours. The file starts with the
    quoted target short name, followed by one line per recorded bound:
    short name, the three difficulty estimates, the bound and its kind
    code ("1000000 7" when there is no bound), the colour, and the time
    taken.

    Raises KeyError if a recorded bound has no timing entry.
    """
    (loop_time, ext_time, tot_time, timing) = load_timing ()

    col = gnuplot_colours[num]
    from target import short_name

    time_ests = mk_timing_metrics ()
    # `with` closes the output even if a lookup or format below fails
    with open ('%s/LoopTimingMetrics.txt' % target_objects.target_dir,
            'w') as f:
        f.write ('"%s"\n' % short_name)

        for (split, ctxt, bound, ests) in time_ests:
            time = timing[(split, tuple (ctxt))]
            if bound == None:
                bdata = "1000000 7"
            else:
                bdata = '%d %d' % (bound[0], bound_kind_nums[bound[1]])
            (l_i, f_i, ct_i) = ests
            f.write ('%s %s %s %s %s %r %s\n' % (short_name, l_i, f_i, ct_i,
                bdata, col, time))
def get_loop_heads (fun):
    """Describe the loops of function fun.

    Returns a list of (address, fun.name, has_inner_loop) triples, one per
    distinct loop, where address is the smallest address node in the loop
    body. A function without an entry gives an empty list.
    """
    if not fun.entry:
        return []
    p = fun.as_problem (problem.Problem)
    p.do_loop_analysis ()
    found = set ()
    for head in p.loop_heads ():
        # any address in the loop will do. pick the smallest one
        body_addrs = [n for n in p.loop_body (head)
            if trace_refute.is_addr (n)]
        entry = (min (body_addrs), fun.name, problem.has_inner_loop (p, head))
        found.add (entry)
    return list (found)
def get_all_loop_heads ():
    """Collect the get_loop_heads triples of every ASM function into a set.

    Functions whose analysis raises problem.Abort are skipped and named in
    a trace message.
    """
    loops = set ()
    abort_funs = set ()
    for fname in all_asm_functions ():
        try:
            heads = get_loop_heads (functions[fname])
        except problem.Abort:
            abort_funs.add (fname)
        else:
            loops.update (heads)
    if abort_funs:
        trace ('Cannot analyse loops in: %s' % ', '.join (abort_funs))
    return loops
def get_complex_loops ():
    """List (loop address, function name) for every loop that has an
    inner loop."""
    complex_loops = []
    for (loop, name, compl) in get_all_loop_heads ():
        if compl:
            complex_loops.append ((loop, name))
    return complex_loops
def search_all_loops ():
    """Compute a global bound (empty calling context) for every loop of
    every ASM function. Results are recorded by get_bound_super_ctxt."""
    for loop_info in get_all_loop_heads ():
        (loop, _, _) = loop_info
        get_bound_super_ctxt (loop, [])
# default entry point: bound every loop
main = search_all_loops

if __name__ == '__main__':
    import sys
    args = target_objects.load_target_args ()
    mode = args[:1]
    if args == ['search']:
        search_all_loops ()
    elif mode == ['metrics']:
        # the remaining arguments name target directories; this target's
        # position among them is passed to save_timing_metrics
        target_names = args[1:]
        num = target_names.index (str (target_objects.target_dir))
        save_timing_metrics (num)