Skip to content

Commit

Permalink
#5 replace template names in default args
Browse files Browse the repository at this point in the history
  • Loading branch information
kwabenantim committed May 2, 2024
1 parent a96d153 commit 7efe7ea
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 54 deletions.
5 changes: 4 additions & 1 deletion cppwg/input/cpp_type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ class CppTypeInfo(BaseInfo):
The name override specified in config e.g. "CustomFoo" -> "Foo"
template_signature : str
The template signature of the type e.g. "<unsigned DIM_A, unsigned DIM_B = DIM_A>"
template_params : List[str]
List of template parameters e.g. ["DIM_A", "DIM_B"]
template_arg_lists : List[List[Any]]
List of template replacement arguments for the type e.g. [[2, 2], [3, 3]]
List of template replacement arguments e.g. [[2, 2], [3, 3]]
decls : pygccxml.declarations.declaration_t
The pygccxml declarations associated with this type, one per template arg if templated
"""
Expand All @@ -36,6 +38,7 @@ def __init__(self, name: str, type_config: Optional[Dict[str, Any]] = None):
self.source_file: Optional[str] = None
self.name_override: Optional[str] = None
self.template_signature: Optional[str] = None
self.template_params: Optional[List[str]] = None
self.template_arg_lists: Optional[List[List[Any]]] = None
self.decls: Optional[List["declaration_t"]] = None # noqa: F821

Expand Down
12 changes: 12 additions & 0 deletions cppwg/input/info_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,16 @@ def extract_templates_from_source(self, feature_info: BaseInfo) -> None:
feature_info.template_signature = template_substitution[
"signature"
]

# Extract ["DIM_A", "DIM_B"] from "<unsigned A, unsigned DIM_B=DIM_A>"
template_params = []
for tp in template_substitution["signature"].split(","):
template_params.append(
tp.replace("<", "")
.replace(">", "")
.split(" ")[1]
.split("=")[0]
.strip()
)
feature_info.template_params = template_params
break
24 changes: 11 additions & 13 deletions cppwg/writers/class_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def add_cpp_header(self, class_full_name: str, class_short_name: str) -> None:
)

def add_virtual_overrides(
self, class_decl: "class_t", short_class_name: str # noqa: F821
self, template_idx: int
) -> List["member_function_t"]: # noqa: F821
"""
Add virtual "trampoline" overrides for the class.
Expand All @@ -157,10 +157,8 @@ def add_virtual_overrides(
Parameters
----------
class_decl : class_t
The class declaration
short_class_name : str
The short name of the class e.g. Foo2_2
template_idx : int
The index of the template in the class info
Returns
-------
Expand All @@ -170,6 +168,8 @@ def add_virtual_overrides(
return_types: List[str] = [] # e.g. ["void", "unsigned int", "::Bar<2> *"]

# Collect all virtual methods and their return types
class_decl = self.class_info.decls[template_idx]

for member_function in class_decl.member_functions(allow_empty=True):
is_pure_virtual = member_function.virtuality == "pure virtual"
is_virtual = member_function.virtuality == "virtual"
Expand All @@ -192,13 +192,14 @@ def add_virtual_overrides(
self.cpp_string += "\n"

# Override virtual methods
short_name = self.class_info.short_names[template_idx]
if methods_needing_override:
# Add virtual override class, e.g.:
# class Foo_Overrides : public Foo {
# public:
# using Foo::Foo;
override_header_dict = {
"class_short_name": short_class_name,
"class_short_name": short_name,
"class_base_name": self.class_info.name,
}

Expand All @@ -217,10 +218,9 @@ def add_virtual_overrides(
for method in methods_needing_override:
method_writer = CppMethodWrapperWriter(
self.class_info,
template_idx,
method,
class_decl,
self.wrapper_templates,
short_class_name,
)
self.cpp_string += method_writer.generate_virtual_override_wrapper()

Expand Down Expand Up @@ -284,7 +284,7 @@ def write(self, work_dir: str) -> None:

# Find and define virtual function "trampoline" overrides
methods_needing_override: List["member_function_t"] = ( # noqa: F821
self.add_virtual_overrides(class_decl, short_name)
self.add_virtual_overrides(idx)
)

# Add the virtual "trampoline" overrides from "Foo_Overrides" to
Expand Down Expand Up @@ -331,10 +331,9 @@ def write(self, work_dir: str) -> None:
):
constructor_writer = CppConstructorWrapperWriter(
self.class_info,
idx,
constructor,
class_decl,
self.wrapper_templates,
short_name,
)
self.cpp_string += constructor_writer.generate_wrapper()

Expand All @@ -350,10 +349,9 @@ def write(self, work_dir: str) -> None:

method_writer = CppMethodWrapperWriter(
self.class_info,
idx,
member_function,
class_decl,
self.wrapper_templates,
short_name,
)
self.cpp_string += method_writer.generate_wrapper()

Expand Down
45 changes: 30 additions & 15 deletions cppwg/writers/constructor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from typing import Dict, Optional

from pygccxml import declarations
from pygccxml.declarations.calldef_members import constructor_t
from pygccxml.declarations.class_declaration import class_t

from cppwg.input.class_info import CppClassInfo
from cppwg.writers.base_writer import CppBaseWrapperWriter


Expand All @@ -18,35 +15,46 @@ class CppConstructorWrapperWriter(CppBaseWrapperWriter):
----------
class_info : ClassInfo
The class information for the class containing the constructor
ctor_decl : constructor_t
template_idx: int
The index of the template in class_info
ctor_decl : pygccxml.declarations.constructor_t
The pygccxml declaration object for the constructor
class_decl : class_t
class_decl : pygccxml.declarations.class_t
The class declaration for the class containing the constructor
wrapper_templates : Dict[str, str]
String templates with placeholders for generating wrapper code
class_short_name : Optional[str]
The short name of the class e.g. 'Foo2_2'
template_params: Optional[List[str]]
The template params for the class e.g. ['DIM_A', 'DIM_B']
template_args: Optional[List[str]]
The template args for the class e.g. ['2', '2']
"""

def __init__(
self,
class_info: CppClassInfo,
ctor_decl: constructor_t,
class_decl: class_t,
class_info: "CppClassInfo", # noqa: F821
template_idx: int,
ctor_decl: "constructor_t", # noqa: F821
wrapper_templates: Dict[str, str],
class_short_name: Optional[str] = None,
) -> None:

super(CppConstructorWrapperWriter, self).__init__(wrapper_templates)

self.class_info: CppClassInfo = class_info
self.ctor_decl: constructor_t = ctor_decl
self.class_decl: class_t = class_decl
self.class_info: "CppClassInfo" = class_info # noqa: F821
self.ctor_decl: "constructor_t" = ctor_decl # noqa: F821
self.class_decl: "class_t" = class_info.decls[template_idx] # noqa: F821

self.class_short_name = class_short_name
self.class_short_name = class_info.short_names[template_idx]
if self.class_short_name is None:
self.class_short_name = self.class_decl.name

self.template_params = class_info.template_params

self.template_args = None
if class_info.template_arg_lists:
self.template_args = class_info.template_arg_lists[template_idx]

def exclusion_criteria(self) -> bool:
"""
Check if the constructor should be excluded from the wrapper code.
Expand Down Expand Up @@ -148,8 +156,15 @@ def generate_wrapper(self) -> str:
default_args += f', py::arg("{arg.name}")'

if arg.default_value is not None:
# TODO: Fix <DIM> in default args (see method_writer)
default_args += f" = {arg.default_value}"
default_value = str(arg.default_value)

if self.template_params:
for param, val in zip(self.template_params, self.template_args):
default_value = default_value.replace(
self.class_info.name + "::" + param, str(val)
).replace(param, str(val))

default_args += f" = {default_value}"

wrapper_string += default_args + ")\n"

Expand Down
52 changes: 27 additions & 25 deletions cppwg/writers/method_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from typing import Dict, Optional

from pygccxml import declarations
from pygccxml.declarations.calldef_members import member_function_t
from pygccxml.declarations.class_declaration import class_t

from cppwg.input.class_info import CppClassInfo
from cppwg.writers.base_writer import CppBaseWrapperWriter


Expand All @@ -18,35 +15,46 @@ class CppMethodWrapperWriter(CppBaseWrapperWriter):
----------
class_info : ClassInfo
The class information for the class containing the method
method_decl : member_function_t
template_idx: int
The index of the template in class_info
method_decl : [pygccxml.declarations.member_function_t]
The pygccxml declaration object for the method
class_decl : class_t
class_decl : [pygccxml.declarations.class_t]
The class declaration for the class containing the method
wrapper_templates : Dict[str, str]
String templates with placeholders for generating wrapper code
class_short_name : Optional[str]
The short name of the class e.g. 'Foo2_2'
template_params: Optional[List[str]]
The template params for the class e.g. ['DIM_A', 'DIM_B']
template_args: Optional[List[str]]
The template args for the class e.g. ['2', '2']
"""

def __init__(
self,
class_info: CppClassInfo,
method_decl: member_function_t,
class_decl: class_t,
class_info: "CppClassInfo", # noqa: F821
template_idx: int,
method_decl: "member_function_t", # noqa: F821
wrapper_templates: Dict[str, str],
class_short_name: Optional[str] = None,
) -> None:

super(CppMethodWrapperWriter, self).__init__(wrapper_templates)

self.class_info: CppClassInfo = class_info
self.method_decl: member_function_t = method_decl
self.class_decl: class_t = class_decl
self.class_info: "CppClassInfo" = class_info # noqa: F821
self.method_decl: "member_function_t" = method_decl # noqa: F821
self.class_decl: "class_t" = class_info.decls[template_idx] # noqa: F821

self.class_short_name: str = class_short_name
self.class_short_name = class_info.short_names[template_idx]
if self.class_short_name is None:
self.class_short_name = self.class_decl.name

self.template_params = class_info.template_params

self.template_args = None
if class_info.template_arg_lists:
self.template_args = class_info.template_arg_lists[template_idx]

def exclusion_criteria(self) -> bool:
"""
Check if the method should be excluded from the wrapper code.
Expand Down Expand Up @@ -136,23 +144,17 @@ def generate_wrapper(self) -> str:
# Default args e.g. py::arg("d") = 1.0
default_args = ""
if not self.default_arg_exclusion_criteria():
for arg, arg_type in zip(
self.method_decl.arguments, self.method_decl.argument_types
):
for arg in self.method_decl.arguments:
default_args += f', py::arg("{arg.name}")'

if arg.default_value is not None:
default_value = str(arg.default_value)

# Hack for missing template in default args
# e.g. Foo<2>::bar(Bar<2> const & b = Bar<DIM>())
# TODO: Make more robust
arg_type_str = str(arg_type).replace(" ", "")
if "<DIM>" in default_value:
if "<2>" in arg_type_str:
default_value = default_value.replace("<DIM>", "<2>")
elif "<3>" in arg_type_str:
default_value = default_value.replace("<DIM>", "<3>")
if self.template_params:
for param, val in zip(self.template_params, self.template_args):
default_value = default_value.replace(
self.class_info.name + "::" + param, str(val)
).replace(param, str(val))

default_args += f" = {default_value}"

Expand Down

0 comments on commit 7efe7ea

Please sign in to comment.