diff --git a/tests/integration_tests/tests/user_agent.rs b/tests/integration_tests/tests/user_agent.rs new file mode 100644 index 000000000..12b7b3f66 --- /dev/null +++ b/tests/integration_tests/tests/user_agent.rs @@ -0,0 +1,55 @@ +use futures_util::FutureExt; +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::time::Duration; +use tokio::sync::oneshot; +use tonic::{ + transport::{Endpoint, Server}, + Request, Response, Status, +}; + +#[tokio::test] +async fn writes_user_agent_header() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + match req.metadata().get("user-agent") { + Some(_) => Ok(Response::new(Output {})), + None => Err(Status::internal("user-agent header is missing")), + } + } + } + + let svc = test_server::TestServer::new(Svc); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1322".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::delay_for(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1322") + .user_agent("my-client") + .expect("valid user agent") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + match client.unary_call(Input {}).await { + Ok(_) => {} + Err(status) => panic!("{}", status.message()), + } + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 5889f64c5..af38d69da 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -6,7 +6,10 @@ use super::ClientTlsConfig; use crate::transport::service::TlsConnector; use crate::transport::Error; use bytes::Bytes; -use http::uri::{InvalidUri, Uri}; +use http::{ + uri::{InvalidUri, Uri}, + HeaderValue, +}; use std::{ convert::{TryFrom, TryInto}, fmt, @@ -20,6 +23,7 @@ use tower_make::MakeConnection; #[derive(Clone)] pub struct Endpoint { pub(crate) uri: Uri, + pub(crate) user_agent: Option, pub(crate) timeout: Option, pub(crate) concurrency_limit: Option, pub(crate) rate_limit: Option<(u64, Duration)>, @@ -74,6 +78,30 @@ impl Endpoint { Ok(Self::from(uri)) } + /// Set a custom user-agent header. + /// + /// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`). + /// It must be a value that can be converted into a valid `http::HeaderValue` or building + /// the endpoint will fail. + /// ``` + /// # use tonic::transport::Endpoint; + /// # let mut builder = Endpoint::from_static("https://example.com"); + /// builder.user_agent("Greeter").expect("Greeter should be a valid header value"); + /// // user-agent: "Greeter tonic/x.x.x" + /// ``` + pub fn user_agent(self, user_agent: T) -> Result + where + T: TryInto, + { + user_agent + .try_into() + .map(|ua| Endpoint { + user_agent: Some(ua), + ..self + }) + .map_err(|_| Error::new_invalid_user_agent()) + } + /// Apply a timeout to each request. /// /// ``` @@ -276,6 +304,7 @@ impl From for Endpoint { fn from(uri: Uri) -> Self { Self { uri, + user_agent: None, concurrency_limit: None, rate_limit: None, timeout: None, diff --git a/tonic/src/transport/error.rs b/tonic/src/transport/error.rs index 584164080..042e5172d 100644 --- a/tonic/src/transport/error.rs +++ b/tonic/src/transport/error.rs @@ -21,6 +21,7 @@ struct ErrorImpl { pub(crate) enum Kind { Transport, InvalidUri, + InvalidUserAgent, } impl Error { @@ -43,10 +44,15 @@ impl Error { Error::new(Kind::InvalidUri) } + pub(crate) fn new_invalid_user_agent() -> Self { + Error::new(Kind::InvalidUserAgent) + } + fn description(&self) -> &str { match &self.inner.kind { Kind::Transport => "transport error", Kind::InvalidUri => "invalid URI", + Kind::InvalidUserAgent => "user agent is not a valid header value", } } } diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index a3a934973..02935301b 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,4 +1,4 @@ -use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin}; +use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{body::BoxBody, transport::Endpoint}; use http::Uri; use hyper::client::conn::Builder; @@ -55,6 +55,7 @@ impl Connection { let stack = ServiceBuilder::new() .layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone())) + .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone())) .optional_layer(endpoint.timeout.map(TimeoutLayer::new)) .optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new)) .optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 92453cdbf..eab3b40ef 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -8,6 +8,7 @@ mod reconnect; mod router; #[cfg(feature = "tls")] mod tls; +mod user_agent; pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; @@ -18,3 +19,4 @@ pub(crate) use self::layer::ServiceBuilderExt; pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; +pub(crate) use self::user_agent::UserAgent; diff --git a/tonic/src/transport/service/user_agent.rs b/tonic/src/transport/service/user_agent.rs new file mode 100644 index 000000000..6ceaea640 --- /dev/null +++ b/tonic/src/transport/service/user_agent.rs @@ -0,0 +1,70 @@ +use http::{header::USER_AGENT, HeaderValue, Request}; +use std::task::{Context, Poll}; +use tower_service::Service; + +const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION")); + +#[derive(Debug)] +pub(crate) struct UserAgent { + inner: T, + user_agent: HeaderValue, +} + +impl UserAgent { + pub(crate) fn new(inner: T, user_agent: Option) -> Self { + let user_agent = user_agent + .map(|value| { + let mut buf = Vec::new(); + buf.extend(value.as_bytes()); + buf.push(b' '); + buf.extend(TONIC_USER_AGENT.as_bytes()); + HeaderValue::from_bytes(&buf).expect("user-agent should be valid") + }) + .unwrap_or(HeaderValue::from_static(TONIC_USER_AGENT)); + + Self { inner, user_agent } + } +} + +impl Service> for UserAgent +where + T: Service>, +{ + type Response = T::Response; + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + req.headers_mut() + .insert(USER_AGENT, self.user_agent.clone()); + + self.inner.call(req) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct Svc; + + #[test] + fn sets_default_if_no_custom_user_agent() { + assert_eq!( + UserAgent::new(Svc, None).user_agent, + HeaderValue::from_static(TONIC_USER_AGENT) + ) + } + + #[test] + fn prepends_custom_user_agent_to_default() { + assert_eq!( + UserAgent::new(Svc, Some(HeaderValue::from_static("Greeter 1.1"))).user_agent, + HeaderValue::from_str(&format!("Greeter 1.1 {}", TONIC_USER_AGENT)).unwrap() + ) + } +}