diff --git a/cppwg/input/class_info.py b/cppwg/input/class_info.py index 01651f9..deed989 100644 --- a/cppwg/input/class_info.py +++ b/cppwg/input/class_info.py @@ -1,8 +1,10 @@ """Class information structure.""" import logging +import re from typing import Any, Dict, List, Optional +from pygccxml.declarations.matchers import access_type_matcher_t from pygccxml.declarations.runtime_errors import declaration_not_found_t from cppwg.input.cpp_type_info import CppTypeInfo @@ -48,7 +50,41 @@ def is_child_of(self, other: "ClassInfo") -> bool: # noqa: F821 return False if not other.decls: return False - return any(base in other.decls for base in self.base_decls) + return any(decl in other.decls for decl in self.base_decls) + + def requires(self, other: "ClassInfo") -> bool: # noqa: F821 + """ + Check if the specified class is used in method signatures of this class. + + Parameters + ---------- + other : ClassInfo + The specified class to check. + + Returns + ------- + bool + True if the specified class is used in method signatures of this class. + """ + if not self.decls: + return False + + query = access_type_matcher_t("public") + name_regex = re.compile(r"\b" + re.escape(other.name) + r"\b") + + for class_decl in self.decls: + method_decls = class_decl.member_functions(function=query, allow_empty=True) + for method_decl in method_decls: + for arg_type in method_decl.argument_types: + if name_regex.search(arg_type.decl_string): + return True + + ctor_decls = class_decl.constructors(function=query, allow_empty=True) + for ctor_decl in ctor_decls: + for arg_type in ctor_decl.argument_types: + if name_regex.search(arg_type.decl_string): + return True + return False def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821 """ @@ -70,22 +106,22 @@ def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821 self.decls = [] for class_cpp_name in self.cpp_names: - decl_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2> + class_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2> try: - class_decl = ns.class_(decl_name) + class_decl = ns.class_(class_name) except declaration_not_found_t as e1: if ( self.template_signature is None or "=" not in self.template_signature ): - logger.error(f"Could not find declaration for class {decl_name}") + logger.error(f"Could not find declaration for class {class_name}") raise e1 # If class has default args, try to compress the template signature logger.warning( - f"Could not find declaration for class {decl_name}: trying for a partial match." + f"Could not find declaration for class {class_name}: trying for a partial match." ) # Try to find the class without default template args @@ -97,16 +133,16 @@ def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821 pos = i break - decl_name = ",".join(decl_name.split(",")[0:pos]) + " >" + class_name = ",".join(class_name.split(",")[0:pos]) + " >" try: - class_decl = ns.class_(decl_name) + class_decl = ns.class_(class_name) except declaration_not_found_t as e2: - logger.error(f"Could not find declaration for class {decl_name}") + logger.error(f"Could not find declaration for class {class_name}") raise e2 - logger.info(f"Found {decl_name}") + logger.info(f"Found {class_name}") self.decls.append(class_decl) diff --git a/cppwg/input/module_info.py b/cppwg/input/module_info.py index 5ab5a85..b5509cd 100644 --- a/cppwg/input/module_info.py +++ b/cppwg/input/module_info.py @@ -77,7 +77,7 @@ def is_decl_in_source_path(self, decl: "declaration_t") -> bool: # noqa: F821 return False def sort_classes(self) -> None: - """Sort the class info collection in inheritance order.""" + """Sort the class info collection in order of dependence.""" self.class_info_collection.sort(key=lambda x: x.name) order_changed = True @@ -94,6 +94,10 @@ def sort_classes(self) -> None: for j in range(i + 1, n): cls_j = self.class_info_collection[j] if cls_i.is_child_of(cls_j): + # sort by inheritance + ii = j + elif cls_i.requires(cls_j) and not cls_j.requires(cls_i): + # sort by dependence, ignoring forward declaration cycles ii = j elif cls_j.is_child_of(cls_i): child_pos.append(j) diff --git a/cppwg/writers/class_writer.py b/cppwg/writers/class_writer.py index cf9d22b..d40d561 100644 --- a/cppwg/writers/class_writer.py +++ b/cppwg/writers/class_writer.py @@ -4,7 +4,8 @@ import os from typing import Dict, List -from pygccxml import declarations +from pygccxml.declarations import type_traits_classes +from pygccxml.declarations.matchers import access_type_matcher_t from cppwg.utils.constants import ( CPPWG_CLASS_OVERRIDE_SUFFIX, @@ -258,7 +259,7 @@ def write(self, work_dir: str) -> None: # struct Foo{ # enum Value{A, B, C}; # }; - if declarations.is_struct(class_decl): + if type_traits_classes.is_struct(class_decl): enums = class_decl.enumerations(allow_empty=True) if len(enums) == 1: @@ -326,7 +327,7 @@ def write(self, work_dir: str) -> None: self.cpp_string += class_definition_template.format(**class_definition_dict) # Add public constructors - query = declarations.access_type_matcher_t("public") + query = access_type_matcher_t("public") for constructor in class_decl.constructors( function=query, allow_empty=True ): @@ -339,7 +340,7 @@ def write(self, work_dir: str) -> None: self.cpp_string += constructor_writer.generate_wrapper() # Add public member functions - query = declarations.access_type_matcher_t("public") + query = access_type_matcher_t("public") for member_function in class_decl.member_functions( function=query, allow_empty=True ): diff --git a/cppwg/writers/constructor_writer.py b/cppwg/writers/constructor_writer.py index d8b4281..f57601a 100644 --- a/cppwg/writers/constructor_writer.py +++ b/cppwg/writers/constructor_writer.py @@ -3,7 +3,7 @@ import re from typing import Dict -from pygccxml import declarations +from pygccxml.declarations import type_traits, type_traits_classes from cppwg.writers.base_writer import CppBaseWrapperWriter @@ -87,9 +87,9 @@ def exclude(self) -> bool: if self.ctor_decl.parent != self.class_decl: return True - # Exclude default copy constructors e.g. Foo::Foo(Foo const & foo) + # Exclude compiler-added copy constructors e.g. Foo::Foo(Foo const & foo) if ( - declarations.is_copy_constructor(self.ctor_decl) + type_traits_classes.is_copy_constructor(self.ctor_decl) and self.ctor_decl.is_artificial ): return True @@ -200,11 +200,8 @@ def generate_wrapper(self) -> str: # `Foo(std::vector laminas = std::vector{})` # which generates `py::arg("laminas") = std::vector{}` if default_value.replace(" ", "") == "{}": - default_value = arg.decl_type.decl_string + " {}" - - # Remove const keyword - default_value = re.sub(r"\bconst\b", "", default_value) - default_value = default_value.replace(" ", " ") + decl_type = type_traits.remove_const(arg.decl_type) + default_value = decl_type.decl_string + " {}" keyword_args += f" = {default_value}" diff --git a/cppwg/writers/method_writer.py b/cppwg/writers/method_writer.py index 6309c2b..2f7e87f 100644 --- a/cppwg/writers/method_writer.py +++ b/cppwg/writers/method_writer.py @@ -3,7 +3,7 @@ import re from typing import Dict -from pygccxml import declarations +from pygccxml.declarations import type_traits from cppwg.writers.base_writer import CppBaseWrapperWriter @@ -129,7 +129,7 @@ def generate_wrapper(self) -> str: # Pybind11 def type e.g. "_static" for def_static() def_adorn = "" if self.method_decl.has_static: - def_adorn += "_static" + def_adorn = "_static" # How to point to class if self.method_decl.has_static: @@ -178,12 +178,12 @@ def generate_wrapper(self) -> str: # Call policy, e.g. "py::return_value_policy::reference" call_policy = "" - if declarations.is_pointer(self.method_decl.return_type): + if type_traits.is_pointer(self.method_decl.return_type): ptr_policy = self.class_info.hierarchy_attribute("pointer_call_policy") if ptr_policy: call_policy = f", py::return_value_policy::{ptr_policy}" - elif declarations.is_reference(self.method_decl.return_type): + elif type_traits.is_reference(self.method_decl.return_type): ref_policy = self.class_info.hierarchy_attribute("reference_call_policy") if ref_policy: call_policy = f", py::return_value_policy::{ref_policy}"