Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): add user-agent header to client requests. #457

Merged
merged 6 commits into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions tests/integration_tests/tests/user_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use futures_util::FutureExt;
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::transport::Endpoint;
use tonic::{transport::Server, Request, Response, Status};
alce marked this conversation as resolved.
Show resolved Hide resolved

#[tokio::test]
async fn writes_user_agent_header() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
match req.metadata().get("user-agent") {
Some(_) => Ok(Response::new(Output {})),
None => Err(Status::internal("user-agent header is missing")),
}
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1322".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::delay_for(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1322")
.user_agent("my-client")
.expect("valid user agent")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

match client.unary_call(Input {}).await {
Ok(_) => {}
Err(status) => panic!("{}", status.message()),
}

tx.send(()).unwrap();

jh.await.unwrap();
}
31 changes: 30 additions & 1 deletion tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use super::ClientTlsConfig;
use crate::transport::service::TlsConnector;
use crate::transport::Error;
use bytes::Bytes;
use http::uri::{InvalidUri, Uri};
use http::{
uri::{InvalidUri, Uri},
HeaderValue,
};
use std::{
convert::{TryFrom, TryInto},
fmt,
Expand All @@ -20,6 +23,7 @@ use tower_make::MakeConnection;
#[derive(Clone)]
pub struct Endpoint {
pub(crate) uri: Uri,
pub(crate) user_agent: Option<HeaderValue>,
pub(crate) timeout: Option<Duration>,
pub(crate) concurrency_limit: Option<usize>,
pub(crate) rate_limit: Option<(u64, Duration)>,
Expand Down Expand Up @@ -74,6 +78,30 @@ impl Endpoint {
Ok(Self::from(uri))
}

/// Set a custom user-agent header.
///
/// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`).
/// It must be a value that can be converted into a valid `http::HeaderValue` or building
/// the endpoint will fail.
/// ```
/// # use tonic::transport::Endpoint;
/// # let mut builder = Endpoint::from_static("https://example.com");
/// builder.user_agent("Greeter").expect("Greeter should be a valid header value");
/// // user-agent: "Greeter tonic/x.x.x"
/// ```
pub fn user_agent<T>(self, user_agent: T) -> Result<Self, Error>
where
T: TryInto<HeaderValue>,
{
user_agent
.try_into()
.map(|ua| Endpoint {
user_agent: Some(ua),
..self
})
.map_err(|_| Error::new_invalid_user_agent())
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -276,6 +304,7 @@ impl From<Uri> for Endpoint {
fn from(uri: Uri) -> Self {
Self {
uri,
user_agent: None,
concurrency_limit: None,
rate_limit: None,
timeout: None,
Expand Down
6 changes: 6 additions & 0 deletions tonic/src/transport/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct ErrorImpl {
pub(crate) enum Kind {
Transport,
InvalidUri,
InvalidUserAgent,
}

impl Error {
Expand All @@ -43,10 +44,15 @@ impl Error {
Error::new(Kind::InvalidUri)
}

pub(crate) fn new_invalid_user_agent() -> Self {
Error::new(Kind::InvalidUserAgent)
}

fn description(&self) -> &str {
match &self.inner.kind {
Kind::Transport => "transport error",
Kind::InvalidUri => "invalid URI",
Kind::InvalidUserAgent => "user agent is not a valid header value",
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin};
use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin, UserAgent};
use crate::{body::BoxBody, transport::Endpoint};
use http::Uri;
use hyper::client::conn::Builder;
Expand Down Expand Up @@ -55,6 +55,7 @@ impl Connection {

let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.optional_layer(endpoint.timeout.map(TimeoutLayer::new))
.optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod reconnect;
mod router;
#[cfg(feature = "tls")]
mod tls;
mod user_agent;

pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
Expand All @@ -18,3 +19,4 @@ pub(crate) use self::layer::ServiceBuilderExt;
pub(crate) use self::router::{Or, Routes};
#[cfg(feature = "tls")]
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
pub(crate) use self::user_agent::UserAgent;
70 changes: 70 additions & 0 deletions tonic/src/transport/service/user_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use http::{header::USER_AGENT, HeaderValue, Request};
use std::task::{Context, Poll};
use tower_service::Service;

const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION"));

#[derive(Debug)]
pub(crate) struct UserAgent<T> {
inner: T,
user_agent: HeaderValue,
}

impl<T> UserAgent<T> {
pub(crate) fn new(inner: T, user_agent: Option<HeaderValue>) -> Self {
let user_agent = user_agent
.map(|value| {
let mut buf = Vec::new();
buf.extend(value.as_bytes());
buf.push(b' ');
buf.extend(TONIC_USER_AGENT.as_bytes());
HeaderValue::from_bytes(&buf).expect("user-agent should be valid")
})
.unwrap_or(HeaderValue::from_static(TONIC_USER_AGENT));

Self { inner, user_agent }
}
}

impl<T, ReqBody> Service<Request<ReqBody>> for UserAgent<T>
where
T: Service<Request<ReqBody>>,
{
type Response = T::Response;
type Error = T::Error;
type Future = T::Future;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
req.headers_mut()
.insert(USER_AGENT, self.user_agent.clone());

self.inner.call(req)
}
}

#[cfg(test)]
mod tests {
use super::*;

struct Svc;

#[test]
fn sets_default_if_no_custom_user_agent() {
assert_eq!(
UserAgent::new(Svc, None).user_agent,
HeaderValue::from_static(TONIC_USER_AGENT)
)
}

#[test]
fn prepends_custom_user_agent_to_default() {
assert_eq!(
UserAgent::new(Svc, Some(HeaderValue::from_static("Greeter 1.1"))).user_agent,
HeaderValue::from_str(&format!("Greeter 1.1 {}", TONIC_USER_AGENT)).unwrap()
)
}
}