diff --git a/cppwg/generators.py b/cppwg/generators.py index 0191df3..e1a956a 100644 --- a/cppwg/generators.py +++ b/cppwg/generators.py @@ -231,26 +231,7 @@ def log_unknown_classes(self) -> None: seen_class_names.add(class_name) logger.info(f"Unknown class {class_name} from {hpp_file_path}") - def map_classes_to_hpp_files(self) -> None: - """ - Map each class to a header file. - - Attempt to map source file paths to each class, assuming the containing - file name is the class name. - """ - for module_info in self.package_info.module_info_collection: - for class_info in module_info.class_info_collection: - # Skip excluded classes - if class_info.excluded: - continue - for hpp_file_path in self.package_info.source_hpp_files: - hpp_file_name = os.path.basename(hpp_file_path) - if class_info.name == os.path.splitext(hpp_file_name)[0]: - class_info.source_file_full_path = hpp_file_path - if class_info.source_file is None: - class_info.source_file = hpp_file_name - - def parse_header_collection(self) -> None: + def parse_headers(self) -> None: """ Parse the hpp files to collect C++ declarations. @@ -279,19 +260,19 @@ def parse_package_info(self) -> None: # If no package info file exists, create a PackageInfo object with default settings self.package_info = PackageInfo("cppwg_package", self.source_root) - def update_from_ns(self) -> None: + def update_modules_from_ns(self) -> None: """ Update modules with information from the parsed source namespace. """ for module_info in self.package_info.module_info_collection: module_info.update_from_ns(self.source_ns) - def update_from_source(self) -> None: + def update_modules_from_source(self) -> None: """ Update modules with information from the source headers. """ for module_info in self.package_info.module_info_collection: - module_info.update_from_source() + module_info.update_from_source(self.package_info.source_hpp_files) def write_header_collection(self) -> None: """ @@ -326,20 +307,17 @@ def generate_wrapper(self) -> None: # Search for header files in the source root self.collect_source_hpp_files() - # Map each class to a header file - self.map_classes_to_hpp_files() - # Update modules with information from the source headers - self.update_from_source() + self.update_modules_from_source() - # Write the header collection to file + # Write the header collection file self.write_header_collection() - # Parse the headers with pygccxml and castxml - self.parse_header_collection() + # Parse the headers with pygccxml (+ castxml) + self.parse_headers() # Update modules with information from the parsed source namespace - self.update_from_ns() + self.update_modules_from_ns() # Log list of unknown classes in the source root self.log_unknown_classes() diff --git a/cppwg/input/class_info.py b/cppwg/input/class_info.py index 07b30c2..61587ae 100644 --- a/cppwg/input/class_info.py +++ b/cppwg/input/class_info.py @@ -1,6 +1,7 @@ """Class information structure.""" import logging +import os import re from typing import Any, Dict, List, Optional @@ -215,7 +216,7 @@ def update_from_ns(self, source_ns: "namespace_t") -> None: # noqa: F821 base.related_class for decl in self.decls for base in decl.bases ] - def update_from_source(self) -> None: + def update_from_source(self, source_files: List[str]) -> None: """ Update class with information from the source headers. """ @@ -223,7 +224,18 @@ def update_from_source(self) -> None: if self.excluded: return + # Map class to a source file, assuming the file name is the class name + for file_path in source_files: + file_name = os.path.basename(file_path) + if self.name == os.path.splitext(file_name)[0]: + self.source_file_full_path = file_path + if self.source_file is None: + self.source_file = file_name + + # Extract template args from the source file self.extract_templates_from_source() + + # Update the C++ and Python class names self.update_names() def update_py_names(self) -> None: diff --git a/cppwg/input/module_info.py b/cppwg/input/module_info.py index 4867b4c..f0b0b28 100644 --- a/cppwg/input/module_info.py +++ b/cppwg/input/module_info.py @@ -162,10 +162,9 @@ def update_from_ns(self, source_ns: "namespace_t") -> None: # noqa: F821 for ff_info in self.free_function_info_collection: ff_info.update_from_ns(source_ns) - def update_from_source(self) -> None: + def update_from_source(self, source_files: List[str]) -> None: """ Update module with information from the source headers. - """ for class_info in self.class_info_collection: - class_info.update_from_source() + class_info.update_from_source(source_files)