Skip to content

Commit

Permalink
Merge #38: Add "extends" inherited method support
Browse files Browse the repository at this point in the history
de748be Add "extends" inherited method support (Russell Yanofsky)

Pull request description:

Top commit has no ACKs.

Tree-SHA512: 2d85996f773ff3a0cfae01096b0358339a3ab50595531002be0be24c80051f7580a208c666cb8ffc603a5233d2f353ee21e4c127bac464200232f9e3b7bfefe7
  • Loading branch information
ryanofsky committed Sep 13, 2020
2 parents 9f5b835 + de748be commit 4c59977
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 10 deletions.
30 changes: 21 additions & 9 deletions src/mp/gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ static bool GetAnnotationInt32(const Reader& reader, uint64_t id, int32_t* resul
return false;
}

void ForEachMethod(const capnp::InterfaceSchema& interface, const std::function<void(const capnp::InterfaceSchema& interface, const capnp::InterfaceSchema::Method)>& callback)
{
for (const auto super : interface.getSuperclasses()) {
ForEachMethod(super, callback);
}
for (const auto method : interface.getMethods()) {
callback(interface, method);
}
}

using CharSlice = kj::ArrayPtr<const char>;

// Overload for any type with a string .begin(), like kj::StringPtr and kj::ArrayPtr<char>.
Expand Down Expand Up @@ -323,12 +333,13 @@ void Generate(kj::StringPtr src_prefix,
std::ostringstream client_construct;
std::ostringstream client_destroy;

for (const auto method : interface.getMethods()) {
int method_ordinal = 0;
ForEachMethod(interface, [&] (const capnp::InterfaceSchema& method_interface, const capnp::InterfaceSchema::Method& method) {
kj::StringPtr method_name = method.getProto().getName();
kj::StringPtr proxied_method_name = method_name;
GetAnnotationText(method.getProto(), NAME_ANNOTATION_ID, &proxied_method_name);

const std::string method_prefix = Format() << message_namespace << "::" << node_name
const std::string method_prefix = Format() << message_namespace << "::" << method_interface.getShortDisplayName()
<< "::" << Cap(method_name);
bool is_construct = method_name == "construct";
bool is_destroy = method_name == "destroy";
Expand Down Expand Up @@ -413,7 +424,7 @@ void Generate(kj::StringPtr src_prefix,
}
}

if (!is_construct && !is_destroy) {
if (!is_construct && !is_destroy && (&method_interface == &interface)) {
methods << "template<>\n";
methods << "struct ProxyMethod<" << method_prefix << "Params>\n";
methods << "{\n";
Expand Down Expand Up @@ -444,7 +455,7 @@ void Generate(kj::StringPtr src_prefix,

for (int i = 0; i < field.args; ++i) {
if (argc > 0) client_args << ",";
client_args << "M" << method.getOrdinal() << "::Param<" << argc << "> " << field_name;
client_args << "M" << method_ordinal << "::Param<" << argc << "> " << field_name;
if (field.args > 1) client_args << i;
++argc;
}
Expand Down Expand Up @@ -481,16 +492,16 @@ void Generate(kj::StringPtr src_prefix,
server_invoke_end << ")";
}

client << " using M" << method.getOrdinal() << " = ProxyClientMethodTraits<" << method_prefix
client << " using M" << method_ordinal << " = ProxyClientMethodTraits<" << method_prefix
<< "Params>;\n";
client << " typename M" << method.getOrdinal() << "::Result " << method_name << "("
client << " typename M" << method_ordinal << "::Result " << method_name << "("
<< client_args.str() << ")";
client << ";\n";
def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method.getOrdinal()
def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal
<< "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name
<< "(" << client_args.str() << ") {\n";
if (has_result) {
def_client << " typename M" << method.getOrdinal() << "::Result result;\n";
def_client << " typename M" << method_ordinal << "::Result result;\n";
}
def_client << " clientInvoke(*this, &" << message_namespace << "::" << node_name
<< "::Client::" << method_name << "Request" << client_invoke.str() << ");\n";
Expand All @@ -511,7 +522,8 @@ void Generate(kj::StringPtr src_prefix,
def_server << "ServerCall()";
}
def_server << server_invoke_end.str() << ");\n}\n";
}
++method_ordinal;
});

client << "};\n";
server << "};\n";
Expand Down
5 changes: 5 additions & 0 deletions test/src/mp/test/foo.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@ interface FooInterface $Proxy.wrap("mp::test::FooImplementation") {
callbackShared @7 (context :Proxy.Context, callback :FooCallback, arg: Int32) -> (result :Int32);
saveCallback @8 (context :Proxy.Context, callback :FooCallback) -> ();
callbackSaved @9 (context :Proxy.Context, arg: Int32) -> (result :Int32);
callbackExtended @10 (context :Proxy.Context, callback :ExtendedCallback, arg: Int32) -> (result :Int32);
}

interface FooCallback $Proxy.wrap("mp::test::FooCallback") {
destroy @0 (context :Proxy.Context) -> ();
call @1 (context :Proxy.Context, arg :Int32) -> (result :Int32);
}

interface ExtendedCallback extends(FooCallback) $Proxy.wrap("mp::test::ExtendedCallback") {
callExtended @0 (context :Proxy.Context, arg :Int32) -> (result :Int32);
}

struct FooStruct $Proxy.wrap("mp::test::FooStruct") {
name @0 :Text;
}
Expand Down
7 changes: 7 additions & 0 deletions test/src/mp/test/foo.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class FooCallback
virtual int call(int arg) = 0;
};

class ExtendedCallback : public FooCallback
{
public:
virtual int callExtended(int arg) = 0;
};

class FooImplementation
{
public:
Expand All @@ -39,6 +45,7 @@ class FooImplementation
int callbackShared(std::shared_ptr<FooCallback> callback, int arg) { return callback->call(arg); }
void saveCallback(std::shared_ptr<FooCallback> callback) { m_callback = std::move(callback); }
int callbackSaved(int arg) { return m_callback->call(arg); }
int callbackExtended(ExtendedCallback& callback, int arg) { return callback.callExtended(arg); }
std::shared_ptr<FooCallback> m_callback;
};

Expand Down
8 changes: 7 additions & 1 deletion test/src/mp/test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ KJ_TEST("Call FooInterface methods")
}
KJ_EXPECT(in.name == err.name);

class Callback : public FooCallback
class Callback : public ExtendedCallback
{
public:
Callback(int expect, int ret) : m_expect(expect), m_ret(ret) {}
Expand All @@ -63,6 +63,11 @@ KJ_TEST("Call FooInterface methods")
KJ_EXPECT(arg == m_expect);
return m_ret;
}
int callExtended(int arg) override
{
KJ_EXPECT(arg == m_expect + 10);
return m_ret + 10;
}
int m_expect, m_ret;
};

Expand All @@ -79,6 +84,7 @@ KJ_TEST("Call FooInterface methods")
KJ_EXPECT(foo->callbackSaved(7) == 8);
foo->saveCallback(nullptr);
KJ_EXPECT(saved.use_count() == 1);
KJ_EXPECT(foo->callbackExtended(callback, 11) == 12);

disconnect_client();
thread.join();
Expand Down

0 comments on commit 4c59977

Please sign in to comment.