Skip to content

Commit

Permalink
#12 Sort classes by dependence
Browse files Browse the repository at this point in the history
  • Loading branch information
kwabenantim committed Sep 22, 2024
1 parent 7cfce4b commit bbb7184
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 26 deletions.
54 changes: 45 additions & 9 deletions cppwg/input/class_info.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Class information structure."""

import logging
import re
from typing import Any, Dict, List, Optional

from pygccxml.declarations.matchers import access_type_matcher_t
from pygccxml.declarations.runtime_errors import declaration_not_found_t

from cppwg.input.cpp_type_info import CppTypeInfo
Expand Down Expand Up @@ -48,7 +50,41 @@ def is_child_of(self, other: "ClassInfo") -> bool: # noqa: F821
return False
if not other.decls:
return False
return any(base in other.decls for base in self.base_decls)
return any(decl in other.decls for decl in self.base_decls)

def requires(self, other: "ClassInfo") -> bool: # noqa: F821
"""
Check if the specified class is used in method signatures of this class.
Parameters
----------
other : ClassInfo
The specified class to check.
Returns
-------
bool
True if the specified class is used in method signatures of this class.
"""
if not self.decls:
return False

query = access_type_matcher_t("public")
name_regex = re.compile(r"\b" + re.escape(other.name) + r"\b")

for class_decl in self.decls:
method_decls = class_decl.member_functions(function=query, allow_empty=True)
for method_decl in method_decls:
for arg_type in method_decl.argument_types:
if name_regex.search(arg_type.decl_string):
return True

ctor_decls = class_decl.constructors(function=query, allow_empty=True)
for ctor_decl in ctor_decls:
for arg_type in ctor_decl.argument_types:
if name_regex.search(arg_type.decl_string):
return True
return False

def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821
"""
Expand All @@ -70,22 +106,22 @@ def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821
self.decls = []

for class_cpp_name in self.cpp_names:
decl_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2>
class_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2>

try:
class_decl = ns.class_(decl_name)
class_decl = ns.class_(class_name)

except declaration_not_found_t as e1:
if (
self.template_signature is None
or "=" not in self.template_signature
):
logger.error(f"Could not find declaration for class {decl_name}")
logger.error(f"Could not find declaration for class {class_name}")
raise e1

# If class has default args, try to compress the template signature
logger.warning(
f"Could not find declaration for class {decl_name}: trying for a partial match."
f"Could not find declaration for class {class_name}: trying for a partial match."
)

# Try to find the class without default template args
Expand All @@ -97,16 +133,16 @@ def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821
pos = i
break

decl_name = ",".join(decl_name.split(",")[0:pos]) + " >"
class_name = ",".join(class_name.split(",")[0:pos]) + " >"

try:
class_decl = ns.class_(decl_name)
class_decl = ns.class_(class_name)

except declaration_not_found_t as e2:
logger.error(f"Could not find declaration for class {decl_name}")
logger.error(f"Could not find declaration for class {class_name}")
raise e2

logger.info(f"Found {decl_name}")
logger.info(f"Found {class_name}")

self.decls.append(class_decl)

Expand Down
6 changes: 5 additions & 1 deletion cppwg/input/module_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def is_decl_in_source_path(self, decl: "declaration_t") -> bool: # noqa: F821
return False

def sort_classes(self) -> None:
"""Sort the class info collection in inheritance order."""
"""Sort the class info collection in order of dependence."""
self.class_info_collection.sort(key=lambda x: x.name)

order_changed = True
Expand All @@ -94,6 +94,10 @@ def sort_classes(self) -> None:
for j in range(i + 1, n):
cls_j = self.class_info_collection[j]
if cls_i.is_child_of(cls_j):
# sort by inheritance
ii = j
elif cls_i.requires(cls_j) and not cls_j.requires(cls_i):
# sort by dependence, ignoring forward declaration cycles
ii = j
elif cls_j.is_child_of(cls_i):
child_pos.append(j)
Expand Down
9 changes: 5 additions & 4 deletions cppwg/writers/class_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import os
from typing import Dict, List

from pygccxml import declarations
from pygccxml.declarations import type_traits_classes
from pygccxml.declarations.matchers import access_type_matcher_t

from cppwg.utils.constants import (
CPPWG_CLASS_OVERRIDE_SUFFIX,
Expand Down Expand Up @@ -258,7 +259,7 @@ def write(self, work_dir: str) -> None:
# struct Foo{
# enum Value{A, B, C};
# };
if declarations.is_struct(class_decl):
if type_traits_classes.is_struct(class_decl):
enums = class_decl.enumerations(allow_empty=True)

if len(enums) == 1:
Expand Down Expand Up @@ -326,7 +327,7 @@ def write(self, work_dir: str) -> None:
self.cpp_string += class_definition_template.format(**class_definition_dict)

# Add public constructors
query = declarations.access_type_matcher_t("public")
query = access_type_matcher_t("public")
for constructor in class_decl.constructors(
function=query, allow_empty=True
):
Expand All @@ -339,7 +340,7 @@ def write(self, work_dir: str) -> None:
self.cpp_string += constructor_writer.generate_wrapper()

# Add public member functions
query = declarations.access_type_matcher_t("public")
query = access_type_matcher_t("public")
for member_function in class_decl.member_functions(
function=query, allow_empty=True
):
Expand Down
13 changes: 5 additions & 8 deletions cppwg/writers/constructor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import Dict

from pygccxml import declarations
from pygccxml.declarations import type_traits, type_traits_classes

from cppwg.writers.base_writer import CppBaseWrapperWriter

Expand Down Expand Up @@ -87,9 +87,9 @@ def exclude(self) -> bool:
if self.ctor_decl.parent != self.class_decl:
return True

# Exclude default copy constructors e.g. Foo::Foo(Foo const & foo)
# Exclude compiler-added copy constructors e.g. Foo::Foo(Foo const & foo)
if (
declarations.is_copy_constructor(self.ctor_decl)
type_traits_classes.is_copy_constructor(self.ctor_decl)
and self.ctor_decl.is_artificial
):
return True
Expand Down Expand Up @@ -200,11 +200,8 @@ def generate_wrapper(self) -> str:
# `Foo(std::vector<Bar*> laminas = std::vector<Bar*>{})`
# which generates `py::arg("laminas") = std::vector<Bar*>{}`
if default_value.replace(" ", "") == "{}":
default_value = arg.decl_type.decl_string + " {}"

# Remove const keyword
default_value = re.sub(r"\bconst\b", "", default_value)
default_value = default_value.replace(" ", " ")
decl_type = type_traits.remove_const(arg.decl_type)
default_value = decl_type.decl_string + " {}"

keyword_args += f" = {default_value}"

Expand Down
8 changes: 4 additions & 4 deletions cppwg/writers/method_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import Dict

from pygccxml import declarations
from pygccxml.declarations import type_traits

from cppwg.writers.base_writer import CppBaseWrapperWriter

Expand Down Expand Up @@ -129,7 +129,7 @@ def generate_wrapper(self) -> str:
# Pybind11 def type e.g. "_static" for def_static()
def_adorn = ""
if self.method_decl.has_static:
def_adorn += "_static"
def_adorn = "_static"

# How to point to class
if self.method_decl.has_static:
Expand Down Expand Up @@ -178,12 +178,12 @@ def generate_wrapper(self) -> str:

# Call policy, e.g. "py::return_value_policy::reference"
call_policy = ""
if declarations.is_pointer(self.method_decl.return_type):
if type_traits.is_pointer(self.method_decl.return_type):
ptr_policy = self.class_info.hierarchy_attribute("pointer_call_policy")
if ptr_policy:
call_policy = f", py::return_value_policy::{ptr_policy}"

elif declarations.is_reference(self.method_decl.return_type):
elif type_traits.is_reference(self.method_decl.return_type):
ref_policy = self.class_info.hierarchy_attribute("reference_call_policy")
if ref_policy:
call_policy = f", py::return_value_policy::{ref_policy}"
Expand Down

0 comments on commit bbb7184

Please sign in to comment.