diff --git a/src/mp/gen.cpp b/src/mp/gen.cpp index 65c78f1..d98e3eb 100644 --- a/src/mp/gen.cpp +++ b/src/mp/gen.cpp @@ -62,6 +62,16 @@ static bool GetAnnotationInt32(const Reader& reader, uint64_t id, int32_t* resul return false; } +void ForEachMethod(const capnp::InterfaceSchema& interface, const std::function& callback) +{ + for (const auto super : interface.getSuperclasses()) { + ForEachMethod(super, callback); + } + for (const auto method : interface.getMethods()) { + callback(interface, method); + } +} + using CharSlice = kj::ArrayPtr; // Overload for any type with a string .begin(), like kj::StringPtr and kj::ArrayPtr. @@ -323,12 +333,13 @@ void Generate(kj::StringPtr src_prefix, std::ostringstream client_construct; std::ostringstream client_destroy; - for (const auto method : interface.getMethods()) { + int method_ordinal = 0; + ForEachMethod(interface, [&] (const capnp::InterfaceSchema& method_interface, const capnp::InterfaceSchema::Method& method) { kj::StringPtr method_name = method.getProto().getName(); kj::StringPtr proxied_method_name = method_name; GetAnnotationText(method.getProto(), NAME_ANNOTATION_ID, &proxied_method_name); - const std::string method_prefix = Format() << message_namespace << "::" << node_name + const std::string method_prefix = Format() << message_namespace << "::" << method_interface.getShortDisplayName() << "::" << Cap(method_name); bool is_construct = method_name == "construct"; bool is_destroy = method_name == "destroy"; @@ -413,7 +424,7 @@ void Generate(kj::StringPtr src_prefix, } } - if (!is_construct && !is_destroy) { + if (!is_construct && !is_destroy && (&method_interface == &interface)) { methods << "template<>\n"; methods << "struct ProxyMethod<" << method_prefix << "Params>\n"; methods << "{\n"; @@ -444,7 +455,7 @@ void Generate(kj::StringPtr src_prefix, for (int i = 0; i < field.args; ++i) { if (argc > 0) client_args << ","; - client_args << "M" << method.getOrdinal() << "::Param<" << argc << "> " << field_name; + client_args << "M" << method_ordinal << "::Param<" << argc << "> " << field_name; if (field.args > 1) client_args << i; ++argc; } @@ -481,16 +492,16 @@ void Generate(kj::StringPtr src_prefix, server_invoke_end << ")"; } - client << " using M" << method.getOrdinal() << " = ProxyClientMethodTraits<" << method_prefix + client << " using M" << method_ordinal << " = ProxyClientMethodTraits<" << method_prefix << "Params>;\n"; - client << " typename M" << method.getOrdinal() << "::Result " << method_name << "(" + client << " typename M" << method_ordinal << "::Result " << method_name << "(" << client_args.str() << ")"; client << ";\n"; - def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method.getOrdinal() + def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal << "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name << "(" << client_args.str() << ") {\n"; if (has_result) { - def_client << " typename M" << method.getOrdinal() << "::Result result;\n"; + def_client << " typename M" << method_ordinal << "::Result result;\n"; } def_client << " clientInvoke(*this, &" << message_namespace << "::" << node_name << "::Client::" << method_name << "Request" << client_invoke.str() << ");\n"; @@ -511,7 +522,8 @@ void Generate(kj::StringPtr src_prefix, def_server << "ServerCall()"; } def_server << server_invoke_end.str() << ");\n}\n"; - } + ++method_ordinal; + }); client << "};\n"; server << "};\n"; diff --git a/test/src/mp/test/foo.capnp b/test/src/mp/test/foo.capnp index 52ad500..333c4e4 100644 --- a/test/src/mp/test/foo.capnp +++ b/test/src/mp/test/foo.capnp @@ -20,6 +20,7 @@ interface FooInterface $Proxy.wrap("mp::test::FooImplementation") { callbackShared @7 (context :Proxy.Context, callback :FooCallback, arg: Int32) -> (result :Int32); saveCallback @8 (context :Proxy.Context, callback :FooCallback) -> (); callbackSaved @9 (context :Proxy.Context, arg: Int32) -> (result :Int32); + callbackExtended @10 (context :Proxy.Context, callback :ExtendedCallback, arg: Int32) -> (result :Int32); } interface FooCallback $Proxy.wrap("mp::test::FooCallback") { @@ -27,6 +28,10 @@ interface FooCallback $Proxy.wrap("mp::test::FooCallback") { call @1 (context :Proxy.Context, arg :Int32) -> (result :Int32); } +interface ExtendedCallback extends(FooCallback) $Proxy.wrap("mp::test::ExtendedCallback") { + callExtended @0 (context :Proxy.Context, arg :Int32) -> (result :Int32); +} + struct FooStruct $Proxy.wrap("mp::test::FooStruct") { name @0 :Text; } diff --git a/test/src/mp/test/foo.h b/test/src/mp/test/foo.h index bc0289b..470facf 100644 --- a/test/src/mp/test/foo.h +++ b/test/src/mp/test/foo.h @@ -26,6 +26,12 @@ class FooCallback virtual int call(int arg) = 0; }; +class ExtendedCallback : public FooCallback +{ +public: + virtual int callExtended(int arg) = 0; +}; + class FooImplementation { public: @@ -39,6 +45,7 @@ class FooImplementation int callbackShared(std::shared_ptr callback, int arg) { return callback->call(arg); } void saveCallback(std::shared_ptr callback) { m_callback = std::move(callback); } int callbackSaved(int arg) { return m_callback->call(arg); } + int callbackExtended(ExtendedCallback& callback, int arg) { return callback.callExtended(arg); } std::shared_ptr m_callback; }; diff --git a/test/src/mp/test/test.cpp b/test/src/mp/test/test.cpp index 41c552f..9c27825 100644 --- a/test/src/mp/test/test.cpp +++ b/test/src/mp/test/test.cpp @@ -54,7 +54,7 @@ KJ_TEST("Call FooInterface methods") } KJ_EXPECT(in.name == err.name); - class Callback : public FooCallback + class Callback : public ExtendedCallback { public: Callback(int expect, int ret) : m_expect(expect), m_ret(ret) {} @@ -63,6 +63,11 @@ KJ_TEST("Call FooInterface methods") KJ_EXPECT(arg == m_expect); return m_ret; } + int callExtended(int arg) override + { + KJ_EXPECT(arg == m_expect + 10); + return m_ret + 10; + } int m_expect, m_ret; }; @@ -79,6 +84,7 @@ KJ_TEST("Call FooInterface methods") KJ_EXPECT(foo->callbackSaved(7) == 8); foo->saveCallback(nullptr); KJ_EXPECT(saved.use_count() == 1); + KJ_EXPECT(foo->callbackExtended(callback, 11) == 12); disconnect_client(); thread.join();