Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): add user-agent header to client requests. #457

Merged
merged 6 commits into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::transport::service::TlsConnector;
use crate::transport::Error;
use bytes::Bytes;
use http::uri::{InvalidUri, Uri};
use http::HeaderValue;
alce marked this conversation as resolved.
Show resolved Hide resolved
use std::{
convert::{TryFrom, TryInto},
fmt,
Expand All @@ -20,6 +21,7 @@ use tower_make::MakeConnection;
#[derive(Clone)]
pub struct Endpoint {
pub(crate) uri: Uri,
pub(crate) user_agent: Option<HeaderValue>,
pub(crate) timeout: Option<Duration>,
pub(crate) concurrency_limit: Option<usize>,
pub(crate) rate_limit: Option<(u64, Duration)>,
Expand Down Expand Up @@ -74,6 +76,30 @@ impl Endpoint {
Ok(Self::from(uri))
}

/// Set a custom user-agent header.
///
/// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`).
/// It must be a value that can be converted into a valid `http::HeaderValue` or building
/// the endpoint will fail.
/// ```
/// # use tonic::transport::Endpoint;
/// # let mut builder = Endpoint::from_static("https://example.com");
/// builder.user_agent("Greeter").expect("Greeter should be a valid header value");
/// // user-agent: "Greeter tonic/x.x.x"
/// ```
pub fn user_agent<T>(self, user_agent: T) -> Result<Self, Error>
where
T: TryInto<HeaderValue>,
{
user_agent
.try_into()
.map(|ua| Endpoint {
user_agent: Some(ua),
..self
})
.map_err(|_| Error::new_invalid_user_agent())
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -276,6 +302,7 @@ impl From<Uri> for Endpoint {
fn from(uri: Uri) -> Self {
Self {
uri,
user_agent: None,
concurrency_limit: None,
rate_limit: None,
timeout: None,
Expand Down
6 changes: 6 additions & 0 deletions tonic/src/transport/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct ErrorImpl {
pub(crate) enum Kind {
Transport,
InvalidUri,
InvalidUserAgent,
}

impl Error {
Expand All @@ -43,10 +44,15 @@ impl Error {
Error::new(Kind::InvalidUri)
}

pub(crate) fn new_invalid_user_agent() -> Self {
Error::new(Kind::InvalidUserAgent)
}

fn description(&self) -> &str {
match &self.inner.kind {
Kind::Transport => "transport error",
Kind::InvalidUri => "invalid URI",
Kind::InvalidUserAgent => "user agent is not a valid header value",
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin};
use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin, UserAgent};
use crate::{body::BoxBody, transport::Endpoint};
use http::Uri;
use hyper::client::conn::Builder;
Expand Down Expand Up @@ -55,6 +55,7 @@ impl Connection {

let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.optional_layer(endpoint.timeout.map(TimeoutLayer::new))
.optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod reconnect;
mod router;
#[cfg(feature = "tls")]
mod tls;
mod user_agent;

pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
Expand All @@ -18,3 +19,4 @@ pub(crate) use self::layer::ServiceBuilderExt;
pub(crate) use self::router::{Or, Routes};
#[cfg(feature = "tls")]
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
pub(crate) use self::user_agent::UserAgent;
47 changes: 47 additions & 0 deletions tonic/src/transport/service/user_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use http::{header::USER_AGENT, HeaderValue, Request};
use std::task::{Context, Poll};
use tower_service::Service;

const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION"));

#[derive(Debug)]
pub(crate) struct UserAgent<T> {
inner: T,
user_agent: Option<HeaderValue>,
}

impl<T> UserAgent<T> {
pub(crate) fn new(inner: T, user_agent: Option<HeaderValue>) -> Self {
Self { inner, user_agent }
}
}

impl<T, ReqBody> Service<Request<ReqBody>> for UserAgent<T>
where
T: Service<Request<ReqBody>>,
{
type Response = T::Response;
type Error = T::Error;
type Future = T::Future;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
let value = match self.user_agent.as_ref() {
Some(custom_ua) => {
alce marked this conversation as resolved.
Show resolved Hide resolved
let mut buf = Vec::new();
buf.extend(custom_ua.as_bytes());
buf.push(b' ');
buf.extend(TONIC_USER_AGENT.as_bytes());
HeaderValue::from_bytes(&buf).expect("user-agent should be valid")
}
None => HeaderValue::from_static(TONIC_USER_AGENT),
};

req.headers_mut().insert(USER_AGENT, value);

self.inner.call(req)
}
}