diff --git a/tests/integration_tests/tests/interceptor.rs b/tests/integration_tests/tests/interceptor.rs new file mode 100644 index 000000000..062e289c0 --- /dev/null +++ b/tests/integration_tests/tests/interceptor.rs @@ -0,0 +1,53 @@ +use std::time::Duration; + +use futures::{channel::oneshot, FutureExt}; +use integration_tests::pb::{test_client::TestClient, test_server, Input, Output}; +use tonic::{ + transport::{Endpoint, Server}, + GrpcMethod, Request, Response, Status, +}; + +#[tokio::test] +async fn interceptor_retrieves_grpc_method() { + use test_server::Test; + + struct Svc; + + #[tonic::async_trait] + impl Test for Svc { + async fn unary_call(&self, _: Request) -> Result, Status> { + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let (tx, rx) = oneshot::channel(); + // Start the server now, second call should succeed + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1340".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + let channel = Endpoint::from_static("http://127.0.0.1:1340").connect_lazy(); + + fn client_intercept(req: Request<()>) -> Result, Status> { + println!("Intercepting client request: {:?}", req); + + let gm = req.extensions().get::().unwrap(); + assert_eq!(gm.service(), "test.Test"); + assert_eq!(gm.method(), "UnaryCall"); + + Ok(req) + } + let mut client = TestClient::with_interceptor(channel, client_intercept); + + tokio::time::sleep(Duration::from_millis(100)).await; + client.unary_call(Request::new(Input {})).await.unwrap(); + + tx.send(()).unwrap(); + jh.await.unwrap(); +} diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index ce35d6616..417474916 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -1,7 +1,10 @@ use std::collections::HashSet; use super::{Attributes, Method, Service}; -use crate::{format_method_name, generate_doc_comments, naive_snake_case}; +use crate::{ + format_method_name, format_method_path, format_service_name, generate_doc_comments, + naive_snake_case, +}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; @@ -51,21 +54,16 @@ pub(crate) fn generate_internal( let connect = generate_connect(&service_ident, build_transport); let package = if emit_package { service.package() } else { "" }; - let path = format!( - "{}{}{}", - package, - if package.is_empty() { "" } else { "." }, - service.identifier() - ); + let service_name = format_service_name(service, emit_package); - let service_doc = if disable_comments.contains(&path) { + let service_doc = if disable_comments.contains(&service_name) { TokenStream::new() } else { generate_doc_comments(service.comment()) }; let mod_attributes = attributes.for_mod(package); - let struct_attributes = attributes.for_struct(&path); + let struct_attributes = attributes.for_struct(&service_name); quote! { /// Generated client implementations. @@ -193,30 +191,41 @@ fn generate_methods( disable_comments: &HashSet, ) -> TokenStream { let mut stream = TokenStream::new(); - let package = if emit_package { service.package() } else { "" }; for method in service.methods() { - let path = format!( - "/{}{}{}/{}", - package, - if package.is_empty() { "" } else { "." }, - service.identifier(), - method.identifier() - ); - - if !disable_comments.contains(&format_method_name(package, service, method)) { + if !disable_comments.contains(&format_method_name(service, method, emit_package)) { stream.extend(generate_doc_comments(method.comment())); } let method = match (method.client_streaming(), method.server_streaming()) { - (false, false) => generate_unary(method, proto_path, compile_well_known_types, path), - (false, true) => { - generate_server_streaming(method, proto_path, compile_well_known_types, path) - } - (true, false) => { - generate_client_streaming(method, proto_path, compile_well_known_types, path) - } - (true, true) => generate_streaming(method, proto_path, compile_well_known_types, path), + (false, false) => generate_unary( + service, + method, + emit_package, + proto_path, + compile_well_known_types, + ), + (false, true) => generate_server_streaming( + service, + method, + emit_package, + proto_path, + compile_well_known_types, + ), + (true, false) => generate_client_streaming( + service, + method, + emit_package, + proto_path, + compile_well_known_types, + ), + (true, true) => generate_streaming( + service, + method, + emit_package, + proto_path, + compile_well_known_types, + ), }; stream.extend(method); @@ -225,15 +234,19 @@ fn generate_methods( stream } -fn generate_unary( - method: &T, +fn generate_unary( + service: &T, + method: &T::Method, + emit_package: bool, proto_path: &str, compile_well_known_types: bool, - path: String, ) -> TokenStream { let codec_name = syn::parse_str::(method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + let service_name = format_service_name(service, emit_package); + let path = format_method_path(service, method, emit_package); + let method_name = method.identifier(); quote! { pub async fn #ident( @@ -245,21 +258,26 @@ fn generate_unary( })?; let codec = #codec_name::default(); let path = http::uri::PathAndQuery::from_static(#path); - self.inner.unary(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name)); + self.inner.unary(req, path, codec).await } } } -fn generate_server_streaming( - method: &T, +fn generate_server_streaming( + service: &T, + method: &T::Method, + emit_package: bool, proto_path: &str, compile_well_known_types: bool, - path: String, ) -> TokenStream { let codec_name = syn::parse_str::(method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); - let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + let service_name = format_service_name(service, emit_package); + let path = format_method_path(service, method, emit_package); + let method_name = method.identifier(); quote! { pub async fn #ident( @@ -271,21 +289,26 @@ fn generate_server_streaming( })?; let codec = #codec_name::default(); let path = http::uri::PathAndQuery::from_static(#path); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name)); + self.inner.server_streaming(req, path, codec).await } } } -fn generate_client_streaming( - method: &T, +fn generate_client_streaming( + service: &T, + method: &T::Method, + emit_package: bool, proto_path: &str, compile_well_known_types: bool, - path: String, ) -> TokenStream { let codec_name = syn::parse_str::(method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); - let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + let service_name = format_service_name(service, emit_package); + let path = format_method_path(service, method, emit_package); + let method_name = method.identifier(); quote! { pub async fn #ident( @@ -297,21 +320,26 @@ fn generate_client_streaming( })?; let codec = #codec_name::default(); let path = http::uri::PathAndQuery::from_static(#path); - self.inner.client_streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name)); + self.inner.client_streaming(req, path, codec).await } } } -fn generate_streaming( - method: &T, +fn generate_streaming( + service: &T, + method: &T::Method, + emit_package: bool, proto_path: &str, compile_well_known_types: bool, - path: String, ) -> TokenStream { let codec_name = syn::parse_str::(method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); - let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + let service_name = format_service_name(service, emit_package); + let path = format_method_path(service, method, emit_package); + let method_name = method.identifier(); quote! { pub async fn #ident( @@ -323,7 +351,9 @@ fn generate_streaming( })?; let codec = #codec_name::default(); let path = http::uri::PathAndQuery::from_static(#path); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut().insert(GrpcMethod::new(#service_name,#method_name)); + self.inner.streaming(req, path, codec).await } } } diff --git a/tonic-build/src/lib.rs b/tonic-build/src/lib.rs index fe3380aa7..8a685c451 100644 --- a/tonic-build/src/lib.rs +++ b/tonic-build/src/lib.rs @@ -197,16 +197,28 @@ impl Attributes { } } -fn format_method_name( - package: &str, - service: &T, - method: &::Method, -) -> String { +fn format_service_name(service: &T, emit_package: bool) -> String { + let package = if emit_package { service.package() } else { "" }; format!( - "{}{}{}.{}", + "{}{}{}", package, if package.is_empty() { "" } else { "." }, service.identifier(), + ) +} + +fn format_method_path(service: &T, method: &T::Method, emit_package: bool) -> String { + format!( + "/{}/{}", + format_service_name(service, emit_package), + method.identifier() + ) +} + +fn format_method_name(service: &T, method: &T::Method, emit_package: bool) -> String { + format!( + "{}.{}", + format_service_name(service, emit_package), method.identifier() ) } diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index e2ba3bb4e..c521ff7c1 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -1,7 +1,10 @@ use std::collections::HashSet; use super::{Attributes, Method, Service}; -use crate::{format_method_name, generate_doc_comment, generate_doc_comments, naive_snake_case}; +use crate::{ + format_method_name, format_method_path, format_service_name, generate_doc_comment, + generate_doc_comments, naive_snake_case, +}; use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{Ident, Lit, LitStr}; @@ -36,7 +39,7 @@ pub(crate) fn generate_internal( attributes: &Attributes, disable_comments: &HashSet, ) -> TokenStream { - let methods = generate_methods(service, proto_path, compile_well_known_types); + let methods = generate_methods(service, emit_package, proto_path, compile_well_known_types); let server_service = quote::format_ident!("{}Server", service.name()); let server_trait = quote::format_ident!("{}", service.name()); @@ -51,22 +54,17 @@ pub(crate) fn generate_internal( ); let package = if emit_package { service.package() } else { "" }; // Transport based implementations - let path = format!( - "{}{}{}", - package, - if package.is_empty() { "" } else { "." }, - service.identifier() - ); + let service_name = format_service_name(service, emit_package); - let service_doc = if disable_comments.contains(&path) { + let service_doc = if disable_comments.contains(&service_name) { TokenStream::new() } else { generate_doc_comments(service.comment()) }; - let named = generate_named(&server_service, &server_trait, &path); + let named = generate_named(&server_service, &server_trait, &service_name); let mod_attributes = attributes.for_mod(package); - let struct_attributes = attributes.for_struct(&path); + let struct_attributes = attributes.for_struct(&service_name); let configure_compression_methods = quote! { /// Enable decompressing requests with the given encoding. @@ -256,19 +254,18 @@ fn generate_trait_methods( ) -> TokenStream { let mut stream = TokenStream::new(); - let package = if emit_package { service.package() } else { "" }; for method in service.methods() { let name = quote::format_ident!("{}", method.name()); let (req_message, res_message) = method.request_response_name(proto_path, compile_well_known_types); - let method_doc = if disable_comments.contains(&format_method_name(package, service, method)) - { - TokenStream::new() - } else { - generate_doc_comments(method.comment()) - }; + let method_doc = + if disable_comments.contains(&format_method_name(service, method, emit_package)) { + TokenStream::new() + } else { + generate_doc_comments(method.comment()) + }; let method = match (method.client_streaming(), method.server_streaming()) { (false, false) => { @@ -341,23 +338,14 @@ fn generate_named( fn generate_methods( service: &T, + emit_package: bool, proto_path: &str, compile_well_known_types: bool, ) -> TokenStream { let mut stream = TokenStream::new(); for method in service.methods() { - let path = format!( - "/{}{}{}/{}", - service.package(), - if service.package().is_empty() { - "" - } else { - "." - }, - service.identifier(), - method.identifier() - ); + let path = format_method_path(service, method, emit_package); let method_path = Lit::Str(LitStr::new(&path, Span::call_site())); let ident = quote::format_ident!("{}", method.name()); let server_trait = quote::format_ident!("{}", service.name()); diff --git a/tonic-health/src/generated/grpc.health.v1.rs b/tonic-health/src/generated/grpc.health.v1.rs index f9794058c..50f583b70 100644 --- a/tonic-health/src/generated/grpc.health.v1.rs +++ b/tonic-health/src/generated/grpc.health.v1.rs @@ -148,7 +148,10 @@ pub mod health_client { let path = http::uri::PathAndQuery::from_static( "/grpc.health.v1.Health/Check", ); - self.inner.unary(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("grpc.health.v1.Health", "Check")); + self.inner.unary(req, path, codec).await } /// Performs a watch for the serving status of the requested service. /// The server will immediately send back a message indicating the current @@ -185,7 +188,10 @@ pub mod health_client { let path = http::uri::PathAndQuery::from_static( "/grpc.health.v1.Health/Watch", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("grpc.health.v1.Health", "Watch")); + self.inner.server_streaming(req, path, codec).await } } } diff --git a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs index b3879091e..4aaaf7992 100644 --- a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs +++ b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs @@ -247,7 +247,15 @@ pub mod server_reflection_client { let path = http::uri::PathAndQuery::from_static( "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", ); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "grpc.reflection.v1alpha.ServerReflection", + "ServerReflectionInfo", + ), + ); + self.inner.streaming(req, path, codec).await } } } diff --git a/tonic/src/codegen.rs b/tonic/src/codegen.rs index 29a302914..14fd57717 100644 --- a/tonic/src/codegen.rs +++ b/tonic/src/codegen.rs @@ -11,6 +11,7 @@ pub use std::task::{Context, Poll}; pub use tower_service::Service; pub type StdError = Box; pub use crate::codec::{CompressionEncoding, EnabledCompressionEncodings}; +pub use crate::extensions::GrpcMethod; pub use crate::service::interceptor::InterceptedService; pub use bytes::Bytes; pub use http; diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs index 470b84af9..8949be929 100644 --- a/tonic/src/extensions.rs +++ b/tonic/src/extensions.rs @@ -71,3 +71,27 @@ impl fmt::Debug for Extensions { f.debug_struct("Extensions").finish() } } + +/// A gRPC Method info extension. +#[derive(Debug)] +pub struct GrpcMethod { + service: &'static str, + method: &'static str, +} + +impl GrpcMethod { + /// Create a new `GrpcMethod` extension. + #[doc(hidden)] + pub fn new(service: &'static str, method: &'static str) -> Self { + Self { service, method } + } + + /// gRPC service name + pub fn service(&self) -> &str { + &self.service + } + /// gRPC method name + pub fn method(&self) -> &str { + &self.method + } +} diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 60fe447cd..4b00ff335 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -111,7 +111,7 @@ pub use async_trait::async_trait; #[doc(inline)] pub use codec::Streaming; -pub use extensions::Extensions; +pub use extensions::{Extensions, GrpcMethod}; pub use request::{IntoRequest, IntoStreamingRequest, Request}; pub use response::Response; pub use status::{Code, Status};