Skip to content

Commit

Permalink
#12 Refactor add_class_decls
Browse files Browse the repository at this point in the history
  • Loading branch information
kwabenantim committed Sep 21, 2024
1 parent 1ee9a5b commit dfa7a16
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 68 deletions.
66 changes: 8 additions & 58 deletions cppwg/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import List, Optional

import pygccxml
from pygccxml.declarations.runtime_errors import declaration_not_found_t

from cppwg.input.class_info import CppClassInfo
from cppwg.input.free_function_info import CppFreeFunctionInfo
Expand Down Expand Up @@ -176,6 +175,8 @@ def collect_source_hpp_files(self) -> None:
patterns e.g. "*.hpp". Skip the wrapper root and wrappers to
avoid pollution.
"""
logger = logging.getLogger()

for root, _, filenames in os.walk(self.source_root, followlinks=True):
for pattern in self.package_info.source_hpp_patterns:
for filename in fnmatch.filter(filenames, pattern):
Expand All @@ -194,7 +195,7 @@ def collect_source_hpp_files(self) -> None:

# Check if any source files were found
if not self.package_info.source_hpp_files:
logging.error(f"No header files found in source root: {self.source_root}")
logger.error(f"No header files found in source root: {self.source_root}")
raise FileNotFoundError()

def extract_templates_from_source(self) -> None:
Expand All @@ -210,6 +211,8 @@ def extract_templates_from_source(self) -> None:

def log_unknown_classes(self) -> None:
"""Get unwrapped classes."""
logger = logging.getLogger()

all_class_decls = self.source_ns.classes(allow_empty=True)

seen_class_names = set()
Expand All @@ -226,7 +229,7 @@ def log_unknown_classes(self) -> None:
):
seen_class_names.add(decl.name)
seen_class_names.add(decl.name.split("<")[0].strip())
logging.info(
logger.info(
f"Unknown class {decl.name} from {decl.location.file_name}:{decl.location.line}"
)

Expand All @@ -238,7 +241,7 @@ def log_unknown_classes(self) -> None:
for _, class_name, _ in class_list:
if class_name not in seen_class_names:
seen_class_names.add(class_name)
logging.info(f"Unknown class {class_name} from {hpp_file_path}")
logger.info(f"Unknown class {class_name} from {hpp_file_path}")

def map_classes_to_hpp_files(self) -> None:
"""
Expand Down Expand Up @@ -314,60 +317,7 @@ def add_class_decls(self) -> None:
declarations found by pygccxml in the C++ source code.
"""
for module_info in self.package_info.module_info_collection:
for class_info in module_info.class_info_collection:
# Skip excluded classes
if class_info.excluded:
continue

class_info.decls: List["class_t"] = [] # noqa: F821

for class_cpp_name in class_info.cpp_names:
decl_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2>

try:
class_decl = self.source_ns.class_(decl_name)

except declaration_not_found_t as e1:
if (
class_info.template_signature is None
or "=" not in class_info.template_signature
):
logging.error(
f"Could not find declaration for class {decl_name}"
)
raise e1

# If class has default args, try to compress the template signature
logging.warning(
f"Could not find declaration for class {decl_name}: trying for a partial match."
)

# Try to find the class without default template args
# e.g. for template <int A, int B=A> class Foo {};
# Look for Foo<2> instead of Foo<2,2>
pos = 0
for i, s in enumerate(class_info.template_signature.split(",")):
if "=" in s:
pos = i
break

decl_name = ",".join(decl_name.split(",")[0:pos]) + " >"

try:
class_decl = self.source_ns.class_(decl_name)

except declaration_not_found_t as e2:
logging.error(
f"Could not find declaration for class {decl_name}"
)
raise e2

logging.info(f"Found {decl_name}")

class_info.decls.append(class_decl)

# Sort the class info collection in inheritance order
module_info.sort_classes()
module_info.update_from_ns(self.source_ns)

def add_discovered_free_functions(self) -> None:
"""
Expand Down
71 changes: 71 additions & 0 deletions cppwg/input/class_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Class information structure."""

import logging
from typing import Any, Dict, List, Optional

from pygccxml.declarations.runtime_errors import declaration_not_found_t

from cppwg.input.cpp_type_info import CppTypeInfo


Expand All @@ -15,6 +18,8 @@ class CppClassInfo(CppTypeInfo):
The C++ names of the class e.g. ["Foo<2,2>", "Foo<3,3>"]
py_names : List[str]
The Python names of the class e.g. ["Foo2_2", "Foo3_3"]
decls : pygccxml.declarations.declaration_t
Declarations for this type's base class, one per template instantiation
"""

def __init__(self, name: str, class_config: Optional[Dict[str, Any]] = None):
Expand All @@ -23,6 +28,72 @@ def __init__(self, name: str, class_config: Optional[Dict[str, Any]] = None):

self.cpp_names: List[str] = None
self.py_names: List[str] = None
self.base_decls: Optional[List["declaration_t"]] = None # noqa: F821

def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821
"""
Update the class information from the source namespace.
Adds the class declarations and base class declarations.
Parameters
----------
ns : pygccxml.declarations.namespace_t
The source namespace
"""
logger = logging.getLogger()

# Skip excluded classes
if self.excluded:
return

self.decls = []

for class_cpp_name in self.cpp_names:
decl_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2>

try:
class_decl = ns.class_(decl_name)

except declaration_not_found_t as e1:
if (
self.template_signature is None
or "=" not in self.template_signature
):
logger.error(f"Could not find declaration for class {decl_name}")
raise e1

# If class has default args, try to compress the template signature
logger.warning(
f"Could not find declaration for class {decl_name}: trying for a partial match."
)

# Try to find the class without default template args
# e.g. for template <int A, int B=A> class Foo {};
# Look for Foo<2> instead of Foo<2,2>
pos = 0
for i, s in enumerate(self.template_signature.split(",")):
if "=" in s:
pos = i
break

decl_name = ",".join(decl_name.split(",")[0:pos]) + " >"

try:
class_decl = ns.class_(decl_name)

except declaration_not_found_t as e2:
logger.error(f"Could not find declaration for class {decl_name}")
raise e2

logger.info(f"Found {decl_name}")

self.decls.append(class_decl)

# Update the base class declarations
self.base_decls = [
base.related_class for decl in self.decls for base in decl.bases
]

def update_py_names(self) -> None:
"""
Expand Down
31 changes: 21 additions & 10 deletions cppwg/input/module_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,32 @@ def compare(class_info_0: "ClassInfo", class_info_1: "ClassInfo"): # noqa: F821
if class_info_1.decls is None:
return -1

# Get the base classes for each class
bases_0 = [
base.related_class for decl in class_info_0.decls for base in decl.bases
]
bases_1 = [
base.related_class for decl in class_info_1.decls for base in decl.bases
]

# 1 if class_0 is a child of class_1
child_0 = int(any(base in class_info_1.decls for base in bases_0))
child_0 = int(
any(base in class_info_1.decls for base in class_info_0.base_decls)
)

# 1 if class_1 is a child of class 0
child_1 = int(any(base in class_info_0.decls for base in bases_1))
child_1 = int(
any(base in class_info_0.decls for base in class_info_1.base_decls)
)

return child_0 - child_1

self.class_info_collection.sort(key=lambda x: x.name)
self.class_info_collection.sort(key=cmp_to_key(compare))

def update_from_ns(self, ns: "namespace_t") -> None: # noqa: F821
"""
Update class info objects with information from the source namespace.
Parameters
----------
ns : pygccxml.declarations.namespace_t
The source namespace
"""
for class_info in self.class_info_collection:
class_info.update_from_ns(ns)

# Sort the class info collection in inheritance order
self.sort_classes()

0 comments on commit dfa7a16

Please sign in to comment.