Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "extends" inherited method support #38

Merged
merged 1 commit into from
Sep 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions src/mp/gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ static bool GetAnnotationInt32(const Reader& reader, uint64_t id, int32_t* resul
return false;
}

void ForEachMethod(const capnp::InterfaceSchema& interface, const std::function<void(const capnp::InterfaceSchema& interface, const capnp::InterfaceSchema::Method)>& callback)
{
for (const auto super : interface.getSuperclasses()) {
ForEachMethod(super, callback);
}
for (const auto method : interface.getMethods()) {
callback(interface, method);
}
}

using CharSlice = kj::ArrayPtr<const char>;

// Overload for any type with a string .begin(), like kj::StringPtr and kj::ArrayPtr<char>.
Expand Down Expand Up @@ -323,12 +333,13 @@ void Generate(kj::StringPtr src_prefix,
std::ostringstream client_construct;
std::ostringstream client_destroy;

for (const auto method : interface.getMethods()) {
int method_ordinal = 0;
ForEachMethod(interface, [&] (const capnp::InterfaceSchema& method_interface, const capnp::InterfaceSchema::Method& method) {
kj::StringPtr method_name = method.getProto().getName();
kj::StringPtr proxied_method_name = method_name;
GetAnnotationText(method.getProto(), NAME_ANNOTATION_ID, &proxied_method_name);

const std::string method_prefix = Format() << message_namespace << "::" << node_name
const std::string method_prefix = Format() << message_namespace << "::" << method_interface.getShortDisplayName()
<< "::" << Cap(method_name);
bool is_construct = method_name == "construct";
bool is_destroy = method_name == "destroy";
Expand Down Expand Up @@ -413,7 +424,7 @@ void Generate(kj::StringPtr src_prefix,
}
}

if (!is_construct && !is_destroy) {
if (!is_construct && !is_destroy && (&method_interface == &interface)) {
methods << "template<>\n";
methods << "struct ProxyMethod<" << method_prefix << "Params>\n";
methods << "{\n";
Expand Down Expand Up @@ -444,7 +455,7 @@ void Generate(kj::StringPtr src_prefix,

for (int i = 0; i < field.args; ++i) {
if (argc > 0) client_args << ",";
client_args << "M" << method.getOrdinal() << "::Param<" << argc << "> " << field_name;
client_args << "M" << method_ordinal << "::Param<" << argc << "> " << field_name;
if (field.args > 1) client_args << i;
++argc;
}
Expand Down Expand Up @@ -481,16 +492,16 @@ void Generate(kj::StringPtr src_prefix,
server_invoke_end << ")";
}

client << " using M" << method.getOrdinal() << " = ProxyClientMethodTraits<" << method_prefix
client << " using M" << method_ordinal << " = ProxyClientMethodTraits<" << method_prefix
<< "Params>;\n";
client << " typename M" << method.getOrdinal() << "::Result " << method_name << "("
client << " typename M" << method_ordinal << "::Result " << method_name << "("
<< client_args.str() << ")";
client << ";\n";
def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method.getOrdinal()
def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal
<< "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name
<< "(" << client_args.str() << ") {\n";
if (has_result) {
def_client << " typename M" << method.getOrdinal() << "::Result result;\n";
def_client << " typename M" << method_ordinal << "::Result result;\n";
}
def_client << " clientInvoke(*this, &" << message_namespace << "::" << node_name
<< "::Client::" << method_name << "Request" << client_invoke.str() << ");\n";
Expand All @@ -511,7 +522,8 @@ void Generate(kj::StringPtr src_prefix,
def_server << "ServerCall()";
}
def_server << server_invoke_end.str() << ");\n}\n";
}
++method_ordinal;
});

client << "};\n";
server << "};\n";
Expand Down
5 changes: 5 additions & 0 deletions test/src/mp/test/foo.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@ interface FooInterface $Proxy.wrap("mp::test::FooImplementation") {
callbackShared @7 (context :Proxy.Context, callback :FooCallback, arg: Int32) -> (result :Int32);
saveCallback @8 (context :Proxy.Context, callback :FooCallback) -> ();
callbackSaved @9 (context :Proxy.Context, arg: Int32) -> (result :Int32);
callbackExtended @10 (context :Proxy.Context, callback :ExtendedCallback, arg: Int32) -> (result :Int32);
}

interface FooCallback $Proxy.wrap("mp::test::FooCallback") {
destroy @0 (context :Proxy.Context) -> ();
call @1 (context :Proxy.Context, arg :Int32) -> (result :Int32);
}

interface ExtendedCallback extends(FooCallback) $Proxy.wrap("mp::test::ExtendedCallback") {
callExtended @0 (context :Proxy.Context, arg :Int32) -> (result :Int32);
}

struct FooStruct $Proxy.wrap("mp::test::FooStruct") {
name @0 :Text;
}
Expand Down
7 changes: 7 additions & 0 deletions test/src/mp/test/foo.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class FooCallback
virtual int call(int arg) = 0;
};

class ExtendedCallback : public FooCallback
{
public:
virtual int callExtended(int arg) = 0;
};

class FooImplementation
{
public:
Expand All @@ -39,6 +45,7 @@ class FooImplementation
int callbackShared(std::shared_ptr<FooCallback> callback, int arg) { return callback->call(arg); }
void saveCallback(std::shared_ptr<FooCallback> callback) { m_callback = std::move(callback); }
int callbackSaved(int arg) { return m_callback->call(arg); }
int callbackExtended(ExtendedCallback& callback, int arg) { return callback.callExtended(arg); }
std::shared_ptr<FooCallback> m_callback;
};

Expand Down
8 changes: 7 additions & 1 deletion test/src/mp/test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ KJ_TEST("Call FooInterface methods")
}
KJ_EXPECT(in.name == err.name);

class Callback : public FooCallback
class Callback : public ExtendedCallback
{
public:
Callback(int expect, int ret) : m_expect(expect), m_ret(ret) {}
Expand All @@ -63,6 +63,11 @@ KJ_TEST("Call FooInterface methods")
KJ_EXPECT(arg == m_expect);
return m_ret;
}
int callExtended(int arg) override
{
KJ_EXPECT(arg == m_expect + 10);
return m_ret + 10;
}
int m_expect, m_ret;
};

Expand All @@ -79,6 +84,7 @@ KJ_TEST("Call FooInterface methods")
KJ_EXPECT(foo->callbackSaved(7) == 8);
foo->saveCallback(nullptr);
KJ_EXPECT(saved.use_count() == 1);
KJ_EXPECT(foo->callbackExtended(callback, 11) == 12);

disconnect_client();
thread.join();
Expand Down