From dfa7a166daaa0e9066e027f5dc00323c92855d47 Mon Sep 17 00:00:00 2001 From: Kwabena N Amponsah Date: Sat, 21 Sep 2024 23:04:53 +0000 Subject: [PATCH] #12 Refactor add_class_decls --- cppwg/generators.py | 66 +++++------------------------------ cppwg/input/class_info.py | 71 ++++++++++++++++++++++++++++++++++++++ cppwg/input/module_info.py | 31 +++++++++++------ 3 files changed, 100 insertions(+), 68 deletions(-) diff --git a/cppwg/generators.py b/cppwg/generators.py index f1b4191..1b3305b 100644 --- a/cppwg/generators.py +++ b/cppwg/generators.py @@ -10,7 +10,6 @@ from typing import List, Optional import pygccxml -from pygccxml.declarations.runtime_errors import declaration_not_found_t from cppwg.input.class_info import CppClassInfo from cppwg.input.free_function_info import CppFreeFunctionInfo @@ -176,6 +175,8 @@ def collect_source_hpp_files(self) -> None: patterns e.g. "*.hpp". Skip the wrapper root and wrappers to avoid pollution. """ + logger = logging.getLogger() + for root, _, filenames in os.walk(self.source_root, followlinks=True): for pattern in self.package_info.source_hpp_patterns: for filename in fnmatch.filter(filenames, pattern): @@ -194,7 +195,7 @@ def collect_source_hpp_files(self) -> None: # Check if any source files were found if not self.package_info.source_hpp_files: - logging.error(f"No header files found in source root: {self.source_root}") + logger.error(f"No header files found in source root: {self.source_root}") raise FileNotFoundError() def extract_templates_from_source(self) -> None: @@ -210,6 +211,8 @@ def extract_templates_from_source(self) -> None: def log_unknown_classes(self) -> None: """Get unwrapped classes.""" + logger = logging.getLogger() + all_class_decls = self.source_ns.classes(allow_empty=True) seen_class_names = set() @@ -226,7 +229,7 @@ def log_unknown_classes(self) -> None: ): seen_class_names.add(decl.name) seen_class_names.add(decl.name.split("<")[0].strip()) - logging.info( + logger.info( f"Unknown class {decl.name} from {decl.location.file_name}:{decl.location.line}" ) @@ -238,7 +241,7 @@ def log_unknown_classes(self) -> None: for _, class_name, _ in class_list: if class_name not in seen_class_names: seen_class_names.add(class_name) - logging.info(f"Unknown class {class_name} from {hpp_file_path}") + logger.info(f"Unknown class {class_name} from {hpp_file_path}") def map_classes_to_hpp_files(self) -> None: """ @@ -314,60 +317,7 @@ def add_class_decls(self) -> None: declarations found by pygccxml in the C++ source code. """ for module_info in self.package_info.module_info_collection: - for class_info in module_info.class_info_collection: - # Skip excluded classes - if class_info.excluded: - continue - - class_info.decls: List["class_t"] = [] # noqa: F821 - - for class_cpp_name in class_info.cpp_names: - decl_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2> - - try: - class_decl = self.source_ns.class_(decl_name) - - except declaration_not_found_t as e1: - if ( - class_info.template_signature is None - or "=" not in class_info.template_signature - ): - logging.error( - f"Could not find declaration for class {decl_name}" - ) - raise e1 - - # If class has default args, try to compress the template signature - logging.warning( - f"Could not find declaration for class {decl_name}: trying for a partial match." - ) - - # Try to find the class without default template args - # e.g. for template class Foo {}; - # Look for Foo<2> instead of Foo<2,2> - pos = 0 - for i, s in enumerate(class_info.template_signature.split(",")): - if "=" in s: - pos = i - break - - decl_name = ",".join(decl_name.split(",")[0:pos]) + " >" - - try: - class_decl = self.source_ns.class_(decl_name) - - except declaration_not_found_t as e2: - logging.error( - f"Could not find declaration for class {decl_name}" - ) - raise e2 - - logging.info(f"Found {decl_name}") - - class_info.decls.append(class_decl) - - # Sort the class info collection in inheritance order - module_info.sort_classes() + module_info.update_from_ns(self.source_ns) def add_discovered_free_functions(self) -> None: """ diff --git a/cppwg/input/class_info.py b/cppwg/input/class_info.py index af4c566..ffd686f 100644 --- a/cppwg/input/class_info.py +++ b/cppwg/input/class_info.py @@ -1,7 +1,10 @@ """Class information structure.""" +import logging from typing import Any, Dict, List, Optional +from pygccxml.declarations.runtime_errors import declaration_not_found_t + from cppwg.input.cpp_type_info import CppTypeInfo @@ -15,6 +18,8 @@ class CppClassInfo(CppTypeInfo): The C++ names of the class e.g. ["Foo<2,2>", "Foo<3,3>"] py_names : List[str] The Python names of the class e.g. ["Foo2_2", "Foo3_3"] + decls : pygccxml.declarations.declaration_t + Declarations for this type's base class, one per template instantiation """ def __init__(self, name: str, class_config: Optional[Dict[str, Any]] = None): @@ -23,6 +28,72 @@ def __init__(self, name: str, class_config: Optional[Dict[str, Any]] = None): self.cpp_names: List[str] = None self.py_names: List[str] = None + self.base_decls: Optional[List["declaration_t"]] = None # noqa: F821 + + def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821 + """ + Update the class information from the source namespace. + + Adds the class declarations and base class declarations. + + Parameters + ---------- + ns : pygccxml.declarations.namespace_t + The source namespace + """ + logger = logging.getLogger() + + # Skip excluded classes + if self.excluded: + return + + self.decls = [] + + for class_cpp_name in self.cpp_names: + decl_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2> + + try: + class_decl = ns.class_(decl_name) + + except declaration_not_found_t as e1: + if ( + self.template_signature is None + or "=" not in self.template_signature + ): + logger.error(f"Could not find declaration for class {decl_name}") + raise e1 + + # If class has default args, try to compress the template signature + logger.warning( + f"Could not find declaration for class {decl_name}: trying for a partial match." + ) + + # Try to find the class without default template args + # e.g. for template class Foo {}; + # Look for Foo<2> instead of Foo<2,2> + pos = 0 + for i, s in enumerate(self.template_signature.split(",")): + if "=" in s: + pos = i + break + + decl_name = ",".join(decl_name.split(",")[0:pos]) + " >" + + try: + class_decl = ns.class_(decl_name) + + except declaration_not_found_t as e2: + logger.error(f"Could not find declaration for class {decl_name}") + raise e2 + + logger.info(f"Found {decl_name}") + + self.decls.append(class_decl) + + # Update the base class declarations + self.base_decls = [ + base.related_class for decl in self.decls for base in decl.bases + ] def update_py_names(self) -> None: """ diff --git a/cppwg/input/module_info.py b/cppwg/input/module_info.py index d658bb0..0fe9edd 100644 --- a/cppwg/input/module_info.py +++ b/cppwg/input/module_info.py @@ -89,21 +89,32 @@ def compare(class_info_0: "ClassInfo", class_info_1: "ClassInfo"): # noqa: F821 if class_info_1.decls is None: return -1 - # Get the base classes for each class - bases_0 = [ - base.related_class for decl in class_info_0.decls for base in decl.bases - ] - bases_1 = [ - base.related_class for decl in class_info_1.decls for base in decl.bases - ] - # 1 if class_0 is a child of class_1 - child_0 = int(any(base in class_info_1.decls for base in bases_0)) + child_0 = int( + any(base in class_info_1.decls for base in class_info_0.base_decls) + ) # 1 if class_1 is a child of class 0 - child_1 = int(any(base in class_info_0.decls for base in bases_1)) + child_1 = int( + any(base in class_info_0.decls for base in class_info_1.base_decls) + ) return child_0 - child_1 self.class_info_collection.sort(key=lambda x: x.name) self.class_info_collection.sort(key=cmp_to_key(compare)) + + def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821 + """ + Update class info objects with information from the source namespace. + + Parameters + ---------- + ns : pygccxml.declarations.namespace_t + The source namespace + """ + for class_info in self.class_info_collection: + class_info.update_from_ns(ns) + + # Sort the class info collection in inheritance order + self.sort_classes()