From 908b45cf5ec97c9297e7c17f3e8ccaab1567edd0 Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Tue, 8 Oct 2024 17:27:11 +0200 Subject: [PATCH] RequestBuilder::query to add query parameters --- Cargo.lock | 1 + Cargo.toml | 1 + src/lib.rs | 1 + src/query.rs | 147 +++++++++++++++++++++++++++++++++++++++++++++++++ src/request.rs | 128 ++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 274 insertions(+), 4 deletions(-) create mode 100644 src/query.rs diff --git a/Cargo.lock b/Cargo.lock index 7e94e3f4..e4474cdf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1014,6 +1014,7 @@ dependencies = [ "log", "native-tls", "once_cell", + "percent-encoding", "rustls", "rustls-pemfile", "rustls-pki-types", diff --git a/Cargo.toml b/Cargo.toml index 9e14a66c..9e606cb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ http = "1.1.0" log = "0.4.22" once_cell = "1.19.0" utf-8 = "0.7.6" +percent-encoding = "2.3.1" # These are used regardless of TLS implementation. rustls-pemfile = { version = "2.1.2", optional = true, default-features = false, features = ["std"] } diff --git a/src/lib.rs b/src/lib.rs index 5e49fcca..28a3dd20 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -377,6 +377,7 @@ mod config; mod error; mod pool; mod proxy; +mod query; mod request; mod run; mod send_body; diff --git a/src/query.rs b/src/query.rs new file mode 100644 index 00000000..822ce3f4 --- /dev/null +++ b/src/query.rs @@ -0,0 +1,147 @@ +use std::borrow::Cow; +use std::fmt; +use std::iter::Enumerate; +use std::ops::Deref; +use std::str::Chars; + +use percent_encoding::utf8_percent_encode; + +#[derive(Clone)] +pub(crate) struct QueryParam<'a> { + source: Source<'a>, +} + +#[derive(Clone)] +enum Source<'a> { + Borrowed(&'a str), + Owned(String), +} + +fn enc(i: &str) -> Cow { + utf8_percent_encode(i, percent_encoding::NON_ALPHANUMERIC).into() +} + +impl<'a> QueryParam<'a> { + pub fn new_key_value(param: &str, value: &str) -> QueryParam<'static> { + let s = format!("{}={}", enc(param), enc(value)); + QueryParam { + source: Source::Owned(s), + } + } + + fn as_str(&self) -> &str { + match &self.source { + Source::Borrowed(v) => v, + Source::Owned(v) => v.as_str(), + } + } +} + +pub(crate) fn parse_query_params(query_string: &str) -> impl Iterator> { + assert!(query_string.is_ascii()); + QueryParamIterator(query_string, query_string.chars().enumerate()) +} + +struct QueryParamIterator<'a>(&'a str, Enumerate>); + +impl<'a> Iterator for QueryParamIterator<'a> { + type Item = QueryParam<'a>; + + fn next(&mut self) -> Option { + let mut first = None; + let mut value = None; + let mut separator = None; + + while let Some((n, c)) = self.1.next() { + if first.is_none() { + first = Some(n); + } + if value.is_none() && c == '=' { + value = Some(n + 1); + } + if c == '&' { + separator = Some(n); + break; + } + } + + if let Some(start) = first { + let end = separator.unwrap_or(self.0.len()); + let chunk = &self.0[start..end]; + return Some(QueryParam { + source: Source::Borrowed(chunk), + }); + } + + None + } +} + +impl<'a> fmt::Debug for QueryParam<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("QueryParam").field(&self.as_str()).finish() + } +} + +impl<'a> fmt::Display for QueryParam<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.source { + Source::Borrowed(v) => write!(f, "{}", v), + Source::Owned(v) => write!(f, "{}", v), + } + } +} + +impl<'a> Deref for QueryParam<'a> { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl<'a> PartialEq for QueryParam<'a> { + fn eq(&self, other: &Self) -> bool { + self.as_str() == other.as_str() + } +} + +#[cfg(test)] +mod test { + use super::*; + + use http::Uri; + + #[test] + fn query_string_does_not_start_with_question_mark() { + let u: Uri = "https://foo.com/qwe?abc=qwe".parse().unwrap(); + assert_eq!(u.query(), Some("abc=qwe")); + } + + #[test] + fn percent_encoding_is_not_decoded() { + let u: Uri = "https://foo.com/qwe?abc=%20123".parse().unwrap(); + assert_eq!(u.query(), Some("abc=%20123")); + } + + #[test] + fn fragments_are_not_a_thing() { + let u: Uri = "https://foo.com/qwe?abc=qwe#yaz".parse().unwrap(); + assert_eq!(u.to_string(), "https://foo.com/qwe?abc=qwe"); + } + + fn p(s: &str) -> Vec { + parse_query_params(s).map(|q| q.to_string()).collect() + } + + #[test] + fn parse_query_string() { + assert_eq!(parse_query_params("").next(), None); + assert_eq!(p("&"), vec![""]); + assert_eq!(p("="), vec!["="]); + assert_eq!(p("&="), vec!["", "="]); + assert_eq!(p("foo=bar"), vec!["foo=bar"]); + assert_eq!(p("foo=bar&"), vec!["foo=bar"]); + assert_eq!(p("foo=bar&foo2=bar2"), vec!["foo=bar", "foo2=bar2"]); + } +} diff --git a/src/request.rs b/src/request.rs index c67331cf..54ba5b18 100644 --- a/src/request.rs +++ b/src/request.rs @@ -7,8 +7,10 @@ use http::{HeaderName, HeaderValue, Method, Request, Response, Uri, Version}; use crate::body::Body; use crate::config::RequestLevelConfig; +use crate::query::{parse_query_params, QueryParam}; use crate::send_body::AsSendBody; use crate::util::private::Private; +use crate::util::UriExt; use crate::{Agent, Config, Error, SendBody, Timeouts}; /// Transparent wrapper around [`http::request::Builder`]. @@ -18,6 +20,7 @@ use crate::{Agent, Config, Error, SendBody, Timeouts}; pub struct RequestBuilder { agent: Agent, builder: http::request::Builder, + query_extra: Vec>, // This is only used in case http::request::Builder contains an error // (such as URL parsing error), and the user wants a `.config()`. @@ -57,6 +60,17 @@ impl RequestBuilder { self } + /// Add a query paramter to the URL. + /// + /// Always appends a new parameter, also when using the name of + /// an already existing one. + /// + /// Using this feature causes an allocation (of a `Vec` holding the parameters). + pub fn query(mut self, key: &str, value: &str) -> Self { + self.query_extra.push(QueryParam::new_key_value(key, value)); + self + } + /// Overrides the URI for this request. /// /// Typically this is set via `ureq::get()` or `Agent::get()`. This @@ -194,6 +208,7 @@ impl RequestBuilder { Self { agent, builder: Request::builder().method(method).uri(uri), + query_extra: vec![], dummy_config: None, _ph: PhantomData, } @@ -210,7 +225,7 @@ impl RequestBuilder { /// ``` pub fn call(self) -> Result, Error> { let request = self.builder.body(())?; - do_call(self.agent, request, SendBody::none()) + do_call(self.agent, request, self.query_extra, SendBody::none()) } } @@ -223,6 +238,7 @@ impl RequestBuilder { Self { agent, builder: Request::builder().method(method).uri(uri), + query_extra: vec![], dummy_config: None, _ph: PhantomData, } @@ -255,7 +271,7 @@ impl RequestBuilder { pub fn send(self, data: impl AsSendBody) -> Result, Error> { let request = self.builder.body(())?; let mut data_ref = data; - do_call(self.agent, request, data_ref.as_body()) + do_call(self.agent, request, self.query_extra, data_ref.as_body()) } /// Send body data as JSON. @@ -285,15 +301,67 @@ impl RequestBuilder { pub fn send_json(self, data: impl serde::ser::Serialize) -> Result, Error> { let request = self.builder.body(())?; let body = SendBody::from_json(&data)?; - do_call(self.agent, request, body) + do_call(self.agent, request, self.query_extra, body) } } -fn do_call(agent: Agent, request: Request<()>, body: SendBody) -> Result, Error> { +fn do_call( + agent: Agent, + mut request: Request<()>, + query_extra: Vec>, + body: SendBody, +) -> Result, Error> { + if !query_extra.is_empty() { + request.uri().ensure_valid_url()?; + request = amend_request_query(request, query_extra.into_iter()); + } let response = agent.run_via_middleware(request, body)?; Ok(response) } +fn amend_request_query( + request: Request<()>, + query_extra: impl Iterator>, +) -> Request<()> { + let (mut parts, body) = request.into_parts(); + let uri = parts.uri; + let mut path = uri.path().to_string(); + let query_existing = parse_query_params(uri.query().unwrap_or("")); + + let mut do_first = true; + + fn append<'a>( + path: &mut String, + do_first: &mut bool, + iter: impl Iterator>, + ) { + for q in iter { + if *do_first { + *do_first = false; + path.push('?'); + } else { + path.push('&'); + } + path.push_str(&q); + } + } + + append(&mut path, &mut do_first, query_existing); + append(&mut path, &mut do_first, query_extra); + + // Unwraps are OK, because we had a correct URI to begin with + let rebuild = Uri::builder() + .scheme(uri.scheme().unwrap().clone()) + .authority(uri.authority().unwrap().clone()) + .path_and_query(path) + .build() + .unwrap(); + + parts.uri = rebuild; + + Request::from_parts(parts, body) +} + impl Deref for RequestBuilder { type Target = http::request::Builder; @@ -368,4 +436,56 @@ mod test { let mut req = get("http://x.y.z/ borked url"); req.timeouts().global = Some(Duration::from_millis(1)); } + + #[test] + fn add_params_to_request_without_query() { + let request = Request::builder() + .uri("https://foo.bar/path") + .body(()) + .unwrap(); + + let amended = amend_request_query( + request, + vec![ + QueryParam::new_key_value("x", "z"), + QueryParam::new_key_value("ab", "cde"), + ] + .into_iter(), + ); + + assert_eq!(amended.uri(), "https://foo.bar/path?x=z&ab=cde"); + } + + #[test] + fn add_params_to_request_with_query() { + let request = Request::builder() + .uri("https://foo.bar/path?x=z") + .body(()) + .unwrap(); + + let amended = amend_request_query( + request, + vec![QueryParam::new_key_value("ab", "cde")].into_iter(), + ); + + assert_eq!(amended.uri(), "https://foo.bar/path?x=z&ab=cde"); + } + + #[test] + fn add_params_that_need_percent_encoding() { + let request = Request::builder() + .uri("https://foo.bar/path") + .body(()) + .unwrap(); + + let amended = amend_request_query( + request, + vec![QueryParam::new_key_value("å ", "i åa ä e ö")].into_iter(), + ); + + assert_eq!( + amended.uri(), + "https://foo.bar/path?%C3%A5%20=i%20%C3%A5a%20%C3%A4%20e%20%C3%B6" + ); + } }