Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper Update #880

Merged
merged 2 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion wrap/DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ The python wrapper supports keyword arguments for functions/methods. Hence, the
template<T, U> class Class2 { ... };
typedef Class2<Type1, Type2> MyInstantiatedClass;
```
- Templates can also be defined for methods, properties and static methods.
- Templates can also be defined for constructors, methods, properties and static methods.
- In the class definition, appearances of the template argument(s) will be replaced with their
instantiated types, e.g. `void setValue(const T& value);`.
- Values scoped within templates are supported. E.g. one can use the form `T::Value` where T is a template, as an argument to a method.
- To refer to the instantiation of the template class itself, use `This`, i.e. `static This Create();`.
- To create new instantiations in other modules, you must copy-and-paste the whole class definition
into the new module, but use only your new instantiation types.
Expand Down
35 changes: 34 additions & 1 deletion wrap/gtwrap/template_instantiator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
import gtwrap.interface_parser as parser


def is_scoped_template(template_typenames, str_arg_typename):
"""
Check if the template given by `str_arg_typename` is a scoped template,
and if so, return what template and index matches the scoped template correctly.
"""
for idx, template in enumerate(template_typenames):
if template in str_arg_typename.split("::"):
return template, idx
return False, -1


def instantiate_type(ctype: parser.Type,
template_typenames: List[str],
instantiations: List[parser.Typename],
Expand Down Expand Up @@ -41,9 +52,30 @@ def instantiate_type(ctype: parser.Type,

str_arg_typename = str(ctype.typename)

# Check if template is a scoped template e.g. T::Value where T is the template
scoped_template, scoped_idx = is_scoped_template(template_typenames,
str_arg_typename)

# Instantiate templates which have enumerated instantiations in the template.
# E.g. `template<T={double}>`.
if str_arg_typename in template_typenames:

# Instantiate scoped templates, e.g. T::Value.
if scoped_template:
# Create a copy of the instantiation so we can modify it.
instantiation = deepcopy(instantiations[scoped_idx])
# Replace the part of the template with the instantiation
instantiation.name = str_arg_typename.replace(scoped_template,
instantiation.name)
return parser.Type(
typename=instantiation,
is_const=ctype.is_const,
is_shared_ptr=ctype.is_shared_ptr,
is_ptr=ctype.is_ptr,
is_ref=ctype.is_ref,
is_basic=ctype.is_basic,
)
# Check for exact template match.
elif str_arg_typename in template_typenames:
idx = template_typenames.index(str_arg_typename)
return parser.Type(
typename=instantiations[idx],
Expand Down Expand Up @@ -418,6 +450,7 @@ def instantiate(instantiated_ctors, ctor, typenames, instantiations):
ctor,
typenames=typenames,
instantiations=self.instantiations)

return instantiated_ctors

def instantiate_static_methods(self, typenames):
Expand Down
225 changes: 225 additions & 0 deletions wrap/tests/expected/matlab/template_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#include <gtwrap/matlab.h>
#include <map>

#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/serialization/export.hpp>



typedef ScopedTemplate<Result> ScopedTemplateResult;

typedef std::set<boost::shared_ptr<TemplatedConstructor>*> Collector_TemplatedConstructor;
static Collector_TemplatedConstructor collector_TemplatedConstructor;
typedef std::set<boost::shared_ptr<ScopedTemplateResult>*> Collector_ScopedTemplateResult;
static Collector_ScopedTemplateResult collector_ScopedTemplateResult;


void _deleteAllObjects()
{
mstream mout;
std::streambuf *outbuf = std::cout.rdbuf(&mout);

bool anyDeleted = false;
{ for(Collector_TemplatedConstructor::iterator iter = collector_TemplatedConstructor.begin();
iter != collector_TemplatedConstructor.end(); ) {
delete *iter;
collector_TemplatedConstructor.erase(iter++);
anyDeleted = true;
} }
{ for(Collector_ScopedTemplateResult::iterator iter = collector_ScopedTemplateResult.begin();
iter != collector_ScopedTemplateResult.end(); ) {
delete *iter;
collector_ScopedTemplateResult.erase(iter++);
anyDeleted = true;
} }

if(anyDeleted)
cout <<
"WARNING: Wrap modules with variables in the workspace have been reloaded due to\n"
"calling destructors, call 'clear all' again if you plan to now recompile a wrap\n"
"module, so that your recompiled module is used instead of the old one." << endl;
std::cout.rdbuf(outbuf);
}

void _template_RTTIRegister() {
const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_template_rttiRegistry_created");
if(!alreadyCreated) {
std::map<std::string, std::string> types;



mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry");
if(!registry)
registry = mxCreateStructMatrix(1, 1, 0, NULL);
typedef std::pair<std::string, std::string> StringPair;
for(const StringPair& rtti_matlab: types) {
int fieldId = mxAddField(registry, rtti_matlab.first.c_str());
if(fieldId < 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str());
mxSetFieldByNumber(registry, 0, fieldId, matlabName);
}
if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxDestroyArray(registry);

mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL);
if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxDestroyArray(newAlreadyCreated);
}
}

void TemplatedConstructor_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef boost::shared_ptr<TemplatedConstructor> Shared;

Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_TemplatedConstructor.insert(self);
}

void TemplatedConstructor_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef boost::shared_ptr<TemplatedConstructor> Shared;

Shared *self = new Shared(new TemplatedConstructor());
collector_TemplatedConstructor.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}

void TemplatedConstructor_constructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef boost::shared_ptr<TemplatedConstructor> Shared;

string& arg = *unwrap_shared_ptr< string >(in[0], "ptr_string");
Shared *self = new Shared(new TemplatedConstructor(arg));
collector_TemplatedConstructor.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}

void TemplatedConstructor_constructor_3(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef boost::shared_ptr<TemplatedConstructor> Shared;

int arg = unwrap< int >(in[0]);
Shared *self = new Shared(new TemplatedConstructor(arg));
collector_TemplatedConstructor.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}

void TemplatedConstructor_constructor_4(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef boost::shared_ptr<TemplatedConstructor> Shared;

double arg = unwrap< double >(in[0]);
Shared *self = new Shared(new TemplatedConstructor(arg));
collector_TemplatedConstructor.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}

void TemplatedConstructor_deconstructor_5(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef boost::shared_ptr<TemplatedConstructor> Shared;
checkArguments("delete_TemplatedConstructor",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_TemplatedConstructor::iterator item;
item = collector_TemplatedConstructor.find(self);
if(item != collector_TemplatedConstructor.end()) {
delete self;
collector_TemplatedConstructor.erase(item);
}
}

void ScopedTemplateResult_collectorInsertAndMakeBase_6(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef boost::shared_ptr<ScopedTemplate<Result>> Shared;

Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_ScopedTemplateResult.insert(self);
}

void ScopedTemplateResult_constructor_7(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef boost::shared_ptr<ScopedTemplate<Result>> Shared;

Result::Value& arg = *unwrap_shared_ptr< Result::Value >(in[0], "ptr_Result::Value");
Shared *self = new Shared(new ScopedTemplate<Result>(arg));
collector_ScopedTemplateResult.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}

void ScopedTemplateResult_deconstructor_8(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef boost::shared_ptr<ScopedTemplate<Result>> Shared;
checkArguments("delete_ScopedTemplateResult",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_ScopedTemplateResult::iterator item;
item = collector_ScopedTemplateResult.find(self);
if(item != collector_ScopedTemplateResult.end()) {
delete self;
collector_ScopedTemplateResult.erase(item);
}
}


void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mstream mout;
std::streambuf *outbuf = std::cout.rdbuf(&mout);

_template_RTTIRegister();

int id = unwrap<int>(in[0]);

try {
switch(id) {
case 0:
TemplatedConstructor_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1);
break;
case 1:
TemplatedConstructor_constructor_1(nargout, out, nargin-1, in+1);
break;
case 2:
TemplatedConstructor_constructor_2(nargout, out, nargin-1, in+1);
break;
case 3:
TemplatedConstructor_constructor_3(nargout, out, nargin-1, in+1);
break;
case 4:
TemplatedConstructor_constructor_4(nargout, out, nargin-1, in+1);
break;
case 5:
TemplatedConstructor_deconstructor_5(nargout, out, nargin-1, in+1);
break;
case 6:
ScopedTemplateResult_collectorInsertAndMakeBase_6(nargout, out, nargin-1, in+1);
break;
case 7:
ScopedTemplateResult_constructor_7(nargout, out, nargin-1, in+1);
break;
case 8:
ScopedTemplateResult_deconstructor_8(nargout, out, nargin-1, in+1);
break;
}
} catch(const std::exception& e) {
mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str());
}

std::cout.rdbuf(outbuf);
}
38 changes: 38 additions & 0 deletions wrap/tests/expected/python/templates_pybind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@


#include <pybind11/eigen.h>
#include <pybind11/stl_bind.h>
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "gtsam/nonlinear/utilities.h" // for RedirectCout.


#include "wrap/serialization.h"
#include <boost/serialization/export.hpp>





using namespace std;

namespace py = pybind11;

PYBIND11_MODULE(templates_py, m_) {
m_.doc() = "pybind11 wrapper of templates_py";


py::class_<TemplatedConstructor, std::shared_ptr<TemplatedConstructor>>(m_, "TemplatedConstructor")
.def(py::init<>())
.def(py::init<const string&>(), py::arg("arg"))
.def(py::init<const int&>(), py::arg("arg"))
.def(py::init<const double&>(), py::arg("arg"));

py::class_<ScopedTemplate<Result>, std::shared_ptr<ScopedTemplate<Result>>>(m_, "ScopedTemplateResult")
.def(py::init<const Result::Value&>(), py::arg("arg"));


#include "python/specializations.h"

}

15 changes: 15 additions & 0 deletions wrap/tests/fixtures/templates.i
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Test for templated constructor
class TemplatedConstructor {
TemplatedConstructor();

template<T={string, int, double}>
TemplatedConstructor(const T& arg);
};

// Test for a scoped value inside a template
template <T = {Result}>
class ScopedTemplate {
// T should be properly substituted here.
ScopedTemplate(const T::Value& arg);
};

17 changes: 17 additions & 0 deletions wrap/tests/test_matlab_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def test_class(self):
for file in files:
self.compare_and_diff(file)

def test_templates(self):
"""Test interface file with template info."""
file = osp.join(self.INTERFACE_DIR, 'templates.i')

wrapper = MatlabWrapper(
module_name='template',
top_module_namespace=['gtsam'],
ignore_classes=[''],
)

wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR)

files = ['template_wrapper.cpp']

for file in files:
self.compare_and_diff(file)

def test_inheritance(self):
"""Test interface file with class inheritance definitions."""
file = osp.join(self.INTERFACE_DIR, 'inheritance.i')
Expand Down
8 changes: 8 additions & 0 deletions wrap/tests/test_pybind_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def test_class(self):

self.compare_and_diff('class_pybind.cpp', output)

def test_templates(self):
"""Test interface file with templated class."""
source = osp.join(self.INTERFACE_DIR, 'templates.i')
output = self.wrap_content([source], 'templates_py',
self.PYTHON_ACTUAL_DIR)

self.compare_and_diff('templates_pybind.cpp', output)

def test_inheritance(self):
"""Test interface file with class inheritance definitions."""
source = osp.join(self.INTERFACE_DIR, 'inheritance.i')
Expand Down