Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix/query param #843

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ http = "1.1.0"
log = "0.4.22"
once_cell = "1.19.0"
utf-8 = "0.7.6"
percent-encoding = "2.3.1"

# These are used regardless of TLS implementation.
rustls-pemfile = { version = "2.1.2", optional = true, default-features = false, features = ["std"] }
Expand Down
10 changes: 6 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#![forbid(unsafe_code)]
#![warn(clippy::all)]
#![deny(missing_docs)]
//!<div align="center">
//! <!-- Version -->
//! <a href="https://crates.io/crates/ureq">
Expand Down Expand Up @@ -352,7 +349,11 @@
//! let resp = agent.get("http://cool.server").call()?;
//! # Ok(())}
//! ```
//!

#![forbid(unsafe_code)]
#![warn(clippy::all)]
#![deny(missing_docs)]

#[macro_use]
extern crate log;

Expand All @@ -376,6 +377,7 @@ mod config;
mod error;
mod pool;
mod proxy;
mod query;
mod request;
mod run;
mod send_body;
Expand Down
147 changes: 147 additions & 0 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
use std::borrow::Cow;
use std::fmt;
use std::iter::Enumerate;
use std::ops::Deref;
use std::str::Chars;

use percent_encoding::utf8_percent_encode;

#[derive(Clone)]
pub(crate) struct QueryParam<'a> {
source: Source<'a>,
}

#[derive(Clone)]
enum Source<'a> {
Borrowed(&'a str),
Owned(String),
}

fn enc(i: &str) -> Cow<str> {
utf8_percent_encode(i, percent_encoding::NON_ALPHANUMERIC).into()
}

impl<'a> QueryParam<'a> {
pub fn new_key_value(param: &str, value: &str) -> QueryParam<'static> {
let s = format!("{}={}", enc(param), enc(value));
QueryParam {
source: Source::Owned(s),
}
}

fn as_str(&self) -> &str {
match &self.source {
Source::Borrowed(v) => v,
Source::Owned(v) => v.as_str(),
}
}
}

pub(crate) fn parse_query_params(query_string: &str) -> impl Iterator<Item = QueryParam<'_>> {
assert!(query_string.is_ascii());
QueryParamIterator(query_string, query_string.chars().enumerate())
}

struct QueryParamIterator<'a>(&'a str, Enumerate<Chars<'a>>);

impl<'a> Iterator for QueryParamIterator<'a> {
type Item = QueryParam<'a>;

fn next(&mut self) -> Option<Self::Item> {
let mut first = None;
let mut value = None;
let mut separator = None;

for (n, c) in self.1.by_ref() {
if first.is_none() {
first = Some(n);
}
if value.is_none() && c == '=' {
value = Some(n + 1);
}
if c == '&' {
separator = Some(n);
break;
}
}

if let Some(start) = first {
let end = separator.unwrap_or(self.0.len());
let chunk = &self.0[start..end];
return Some(QueryParam {
source: Source::Borrowed(chunk),
});
}

None
}
}

impl<'a> fmt::Debug for QueryParam<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("QueryParam").field(&self.as_str()).finish()
}
}

impl<'a> fmt::Display for QueryParam<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.source {
Source::Borrowed(v) => write!(f, "{}", v),
Source::Owned(v) => write!(f, "{}", v),
}
}
}

impl<'a> Deref for QueryParam<'a> {
type Target = str;

fn deref(&self) -> &Self::Target {
self.as_str()
}
}

impl<'a> PartialEq for QueryParam<'a> {
fn eq(&self, other: &Self) -> bool {
self.as_str() == other.as_str()
}
}

#[cfg(test)]
mod test {
use super::*;

use http::Uri;

#[test]
fn query_string_does_not_start_with_question_mark() {
let u: Uri = "https://foo.com/qwe?abc=qwe".parse().unwrap();
assert_eq!(u.query(), Some("abc=qwe"));
}

#[test]
fn percent_encoding_is_not_decoded() {
let u: Uri = "https://foo.com/qwe?abc=%20123".parse().unwrap();
assert_eq!(u.query(), Some("abc=%20123"));
}

#[test]
fn fragments_are_not_a_thing() {
let u: Uri = "https://foo.com/qwe?abc=qwe#yaz".parse().unwrap();
assert_eq!(u.to_string(), "https://foo.com/qwe?abc=qwe");
}

fn p(s: &str) -> Vec<String> {
parse_query_params(s).map(|q| q.to_string()).collect()
}

#[test]
fn parse_query_string() {
assert_eq!(parse_query_params("").next(), None);
assert_eq!(p("&"), vec![""]);
assert_eq!(p("="), vec!["="]);
assert_eq!(p("&="), vec!["", "="]);
assert_eq!(p("foo=bar"), vec!["foo=bar"]);
assert_eq!(p("foo=bar&"), vec!["foo=bar"]);
assert_eq!(p("foo=bar&foo2=bar2"), vec!["foo=bar", "foo2=bar2"]);
}
}
138 changes: 134 additions & 4 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ use http::{HeaderName, HeaderValue, Method, Request, Response, Uri, Version};

use crate::body::Body;
use crate::config::RequestLevelConfig;
use crate::query::{parse_query_params, QueryParam};
use crate::send_body::AsSendBody;
use crate::util::private::Private;
use crate::util::UriExt;
use crate::{Agent, Config, Error, SendBody, Timeouts};

/// Transparent wrapper around [`http::request::Builder`].
Expand All @@ -18,6 +20,7 @@ use crate::{Agent, Config, Error, SendBody, Timeouts};
pub struct RequestBuilder<B> {
agent: Agent,
builder: http::request::Builder,
query_extra: Vec<QueryParam<'static>>,

// This is only used in case http::request::Builder contains an error
// (such as URL parsing error), and the user wants a `.config()`.
Expand Down Expand Up @@ -57,6 +60,27 @@ impl<Any> RequestBuilder<Any> {
self
}

/// Add a query paramter to the URL.
///
/// Always appends a new parameter, also when using the name of
/// an already existing one.
///
/// # Examples
///
/// ```
/// let req = ureq::get("https://httpbin.org/get")
/// .query("my_query", "with_value");
/// ```
pub fn query<K, V>(mut self, key: K, value: V) -> Self
where
K: AsRef<str>,
V: AsRef<str>,
{
self.query_extra
.push(QueryParam::new_key_value(key.as_ref(), value.as_ref()));
self
}

/// Overrides the URI for this request.
///
/// Typically this is set via `ureq::get(<uri>)` or `Agent::get(<uri>)`. This
Expand Down Expand Up @@ -194,6 +218,7 @@ impl RequestBuilder<WithoutBody> {
Self {
agent,
builder: Request::builder().method(method).uri(uri),
query_extra: vec![],
dummy_config: None,
_ph: PhantomData,
}
Expand All @@ -210,7 +235,7 @@ impl RequestBuilder<WithoutBody> {
/// ```
pub fn call(self) -> Result<Response<Body>, Error> {
let request = self.builder.body(())?;
do_call(self.agent, request, SendBody::none())
do_call(self.agent, request, self.query_extra, SendBody::none())
}
}

Expand All @@ -223,6 +248,7 @@ impl RequestBuilder<WithBody> {
Self {
agent,
builder: Request::builder().method(method).uri(uri),
query_extra: vec![],
dummy_config: None,
_ph: PhantomData,
}
Expand Down Expand Up @@ -255,7 +281,7 @@ impl RequestBuilder<WithBody> {
pub fn send(self, data: impl AsSendBody) -> Result<Response<Body>, Error> {
let request = self.builder.body(())?;
let mut data_ref = data;
do_call(self.agent, request, data_ref.as_body())
do_call(self.agent, request, self.query_extra, data_ref.as_body())
}

/// Send body data as JSON.
Expand Down Expand Up @@ -285,15 +311,67 @@ impl RequestBuilder<WithBody> {
pub fn send_json(self, data: impl serde::ser::Serialize) -> Result<Response<Body>, Error> {
let request = self.builder.body(())?;
let body = SendBody::from_json(&data)?;
do_call(self.agent, request, body)
do_call(self.agent, request, self.query_extra, body)
}
}

fn do_call(agent: Agent, request: Request<()>, body: SendBody) -> Result<Response<Body>, Error> {
fn do_call(
agent: Agent,
mut request: Request<()>,
query_extra: Vec<QueryParam<'static>>,
body: SendBody,
) -> Result<Response<Body>, Error> {
if !query_extra.is_empty() {
request.uri().ensure_valid_url()?;
request = amend_request_query(request, query_extra.into_iter());
}
let response = agent.run_via_middleware(request, body)?;
Ok(response)
}

fn amend_request_query(
request: Request<()>,
query_extra: impl Iterator<Item = QueryParam<'static>>,
) -> Request<()> {
let (mut parts, body) = request.into_parts();
let uri = parts.uri;
let mut path = uri.path().to_string();
let query_existing = parse_query_params(uri.query().unwrap_or(""));

let mut do_first = true;

fn append<'a>(
path: &mut String,
do_first: &mut bool,
iter: impl Iterator<Item = QueryParam<'a>>,
) {
for q in iter {
if *do_first {
*do_first = false;
path.push('?');
} else {
path.push('&');
}
path.push_str(&q);
}
}

append(&mut path, &mut do_first, query_existing);
append(&mut path, &mut do_first, query_extra);

// Unwraps are OK, because we had a correct URI to begin with
let rebuild = Uri::builder()
.scheme(uri.scheme().unwrap().clone())
.authority(uri.authority().unwrap().clone())
.path_and_query(path)
.build()
.unwrap();

parts.uri = rebuild;

Request::from_parts(parts, body)
}

impl<MethodLimit> Deref for RequestBuilder<MethodLimit> {
type Target = http::request::Builder;

Expand Down Expand Up @@ -368,4 +446,56 @@ mod test {
let mut req = get("http://x.y.z/ borked url");
req.timeouts().global = Some(Duration::from_millis(1));
}

#[test]
fn add_params_to_request_without_query() {
let request = Request::builder()
.uri("https://foo.bar/path")
.body(())
.unwrap();

let amended = amend_request_query(
request,
vec![
QueryParam::new_key_value("x", "z"),
QueryParam::new_key_value("ab", "cde"),
]
.into_iter(),
);

assert_eq!(amended.uri(), "https://foo.bar/path?x=z&ab=cde");
}

#[test]
fn add_params_to_request_with_query() {
let request = Request::builder()
.uri("https://foo.bar/path?x=z")
.body(())
.unwrap();

let amended = amend_request_query(
request,
vec![QueryParam::new_key_value("ab", "cde")].into_iter(),
);

assert_eq!(amended.uri(), "https://foo.bar/path?x=z&ab=cde");
}

#[test]
fn add_params_that_need_percent_encoding() {
let request = Request::builder()
.uri("https://foo.bar/path")
.body(())
.unwrap();

let amended = amend_request_query(
request,
vec![QueryParam::new_key_value("å ", "i åa ä e ö")].into_iter(),
);

assert_eq!(
amended.uri(),
"https://foo.bar/path?%C3%A5%20=i%20%C3%A5a%20%C3%A4%20e%20%C3%B6"
);
}
}
Loading