diff --git a/cppwg/input/module_info.py b/cppwg/input/module_info.py index 88bd064..fbbea44 100644 --- a/cppwg/input/module_info.py +++ b/cppwg/input/module_info.py @@ -85,6 +85,46 @@ def sort_classes(self) -> None: """ Sort the class info collection in order of dependence. """ + cache = dict() + + def compare(a: CppClassInfo, b: CppClassInfo) -> int: + """ + Compare two class info objects for dependence order. + + Parameters + ---------- + a : CppClassInfo + b : CppClassInfo + + Returns + ------- + int + -1 if a comes before b (b depends on a) + 0 if there is no dependence + 1 if a comes after b (a depends on b) + """ + order = cache.get((a, b), None) + if order is not None: + return order + + a_req_b = a.requires(b) + b_req_a = b.requires(a) + if a.is_child_of(b) or (a_req_b and not b_req_a): + # a comes after b (ignore cyclic dependencies) + cache[(a, b)] = 1 + cache[(b, a)] = -1 + return 1 + elif b.is_child_of(a) or (b_req_a and not a_req_b): + # a comes before b (ignore cyclic dependencies) + cache[(a, b)] = -1 + cache[(b, a)] = 1 + return -1 + + # Order doesn't matter + cache[(a, b)] = 0 + cache[(b, a)] = 0 + return 0 + self.class_info_collection.sort(key=lambda x: x.name) i = 0 @@ -96,13 +136,11 @@ def sort_classes(self) -> None: for j in range(i + 1, n): cls_j = self.class_info_collection[j] - i_req_j = cls_i.requires(cls_j) - j_req_i = cls_j.requires(cls_i) - if cls_i.is_child_of(cls_j) or (i_req_j and not j_req_i): - # Position cls_i after all classes it depends on, - # ignoring forward declaration cycles + order = compare(cls_i, cls_j) + if order == 1: + # Position cls_i after all classes it depends on ii = j - elif cls_j.is_child_of(cls_i) or (j_req_i and not i_req_j): + elif order == -1: # Collect positions of cls_i's dependents j_pos.append(j)