diff --git a/Cargo.lock b/Cargo.lock index 92f8a272..72572ed1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -325,6 +325,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", +] + [[package]] name = "blocking" version = "1.0.2" @@ -531,6 +540,15 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b" +[[package]] +name = "cpufeatures" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.2.1" @@ -615,6 +633,28 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + +[[package]] +name = "digest_auth" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa30657988b2ced88f68fe490889e739bf98d342916c33ed3100af1d6f1cbc9c" +dependencies = [ + "digest", + "hex", + "md-5", + "rand", + "sha2", +] + [[package]] name = "dirs" version = "3.0.2" @@ -860,6 +900,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "501466ecc8a30d1d3b7fc9229b122b2ce8ed6e9d9223f1138d4babb253e51817" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.21" @@ -936,6 +986,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "0.2.4" @@ -1290,6 +1346,17 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" +[[package]] +name = "md-5" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" +dependencies = [ + "block-buffer", + "digest", + "opaque-debug", +] + [[package]] name = "memchr" version = "2.4.1" @@ -1458,6 +1525,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "openssl" version = "0.10.36" @@ -2086,6 +2159,19 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2579985fda508104f7587689507983eadd6a6e84dd35d6d115361f530916fa0d" +[[package]] +name = "sha2" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b69f9a4c9740d74c5baa3fd2e547f9525fa8088a8a958e0ca2409a514e33f5fa" +dependencies = [ + "block-buffer", + "cfg-if", + "cpufeatures", + "digest", + "opaque-debug", +] + [[package]] name = "shell-escape" version = "0.1.5" @@ -2567,6 +2653,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "typenum" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b63708a265f51345575b27fe43f9500ad611579e764c79edbc2037b1121959ec" + [[package]] name = "unicase" version = "2.6.0" @@ -2860,6 +2952,7 @@ dependencies = [ "cookie 0.15.1", "cookie_store 0.15.0", "curl", + "digest_auth", "dirs", "encoding_rs", "encoding_rs_io", diff --git a/Cargo.toml b/Cargo.toml index 68b6bb20..9fb423dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ base64 = "0.13" bytes = "1.0.1" cookie_crate = { version = "0.15", package = "cookie" } cookie_store = { version = "0.15.0" } +digest_auth = "0.3.0" dirs = "3.0.1" encoding_rs = "0.8.28" encoding_rs_io = "0.1.7" diff --git a/src/auth.rs b/src/auth.rs index bfe50593..d66390ec 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,85 +1,143 @@ use std::env; +use std::fs; use std::io; use std::path::PathBuf; -use crate::regex; -use dirs::home_dir; +use anyhow::Result; use netrc_rs::Netrc; -use std::fs; +use reqwest::blocking::{Request, Response}; +use reqwest::header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}; +use reqwest::StatusCode; -pub fn parse_auth(auth: String, host: &str) -> io::Result<(String, Option)> { - if let Some(cap) = regex!(r"^([^:]*):$").captures(&auth) { +use crate::cli::AuthType; +use crate::middleware::{Context, Middleware}; +use crate::regex; +use crate::utils::{clone_request, get_home_dir}; + +#[derive(Debug, PartialEq, Eq)] +pub enum Auth { + Bearer(String), + Basic(String, Option), + Digest(String, String), +} + +impl Auth { + pub fn from_str(auth: &str, auth_type: AuthType, host: &str) -> Result { + match auth_type { + AuthType::basic => { + let (username, password) = parse_auth(auth, host)?; + Ok(Auth::Basic(username, password)) + } + AuthType::digest => { + let (username, password) = parse_auth(auth, host)?; + Ok(Auth::Digest( + username, + password.unwrap_or_else(|| "".into()), + )) + } + AuthType::bearer => Ok(Auth::Bearer(auth.into())), + } + } + + pub fn from_netrc(netrc: &str, auth_type: AuthType, host: &str) -> Option { + Netrc::parse_borrow(&netrc, false) + .ok()? + .machines + .into_iter() + .filter_map(|machine| match machine.name { + Some(name) if name == host => { + let username = machine.login.unwrap_or_else(|| "".into()); + let password = machine.password; + match auth_type { + AuthType::basic => Some(Auth::Basic(username, password)), + AuthType::digest => Some(Auth::Digest( + username, + password.unwrap_or_else(|| "".into()), + )), + AuthType::bearer => None, + } + } + _ => None, + }) + .last() + } +} + +pub fn parse_auth(auth: &str, host: &str) -> io::Result<(String, Option)> { + if let Some(cap) = regex!(r"^([^:]*):$").captures(auth) { Ok((cap[1].to_string(), None)) - } else if let Some(cap) = regex!(r"^(.+?):(.+)$").captures(&auth) { + } else if let Some(cap) = regex!(r"^(.+?):(.+)$").captures(auth) { let username = cap[1].to_string(); let password = cap[2].to_string(); Ok((username, Some(password))) } else { - let username = auth; + let username = auth.to_string(); let prompt = format!("http: password for {}@{}: ", username, host); let password = rpassword::read_password_from_tty(Some(&prompt))?; Ok((username, Some(password))) } } -fn get_home_dir() -> Option { - #[cfg(target_os = "windows")] - if let Some(path) = env::var_os("XH_TEST_MODE_WIN_HOME_DIR") { - return Some(PathBuf::from(path)); - } - - home_dir() -} - -fn netrc_path() -> Option { - match env::var_os("NETRC") { +pub fn read_netrc() -> Option { + let netrc_path = match env::var_os("NETRC") { Some(path) => { - let pth = PathBuf::from(path); - if pth.exists() { - Some(pth) + let path = PathBuf::from(path); + if path.exists() { + Some(path) } else { None } } None => { - if let Some(hd_path) = get_home_dir() { - [".netrc", "_netrc"] - .iter() - .map(|f| hd_path.join(f)) - .find(|p| p.exists()) - } else { - None - } + let home_dir = get_home_dir()?; + [".netrc", "_netrc"] + .iter() + .map(|f| home_dir.join(f)) + .find(|p| p.exists()) } - } + }?; + + fs::read_to_string(netrc_path).ok() } -pub fn read_netrc() -> Option { - if let Some(netrc_path) = netrc_path() { - if let Ok(result) = fs::read_to_string(netrc_path) { - return Some(result); - } - }; +pub struct DigestAuthMiddleware<'a> { + username: &'a str, + password: &'a str, +} - None +impl<'a> DigestAuthMiddleware<'a> { + pub fn new(username: &'a str, password: &'a str) -> Self { + DigestAuthMiddleware { username, password } + } } -pub fn auth_from_netrc(machine: &str, netrc: &str) -> Option<(String, Option)> { - if let Ok(netrc) = Netrc::parse_borrow(&netrc, false) { - return netrc - .machines - .into_iter() - .filter_map(|mach| match mach.name { - Some(name) if name == machine => { - let user = mach.login.unwrap_or_else(|| "".to_string()); - Some((user, mach.password)) +impl<'a> Middleware for DigestAuthMiddleware<'a> { + fn handle(&mut self, mut ctx: Context, mut request: Request) -> Result { + let response = self.next(&mut ctx, clone_request(&mut request)?)?; + match response.headers().get(WWW_AUTHENTICATE) { + Some(wwwauth) if response.status() == StatusCode::UNAUTHORIZED => { + let mut context = digest_auth::AuthContext::new( + self.username, + self.password, + request.url().path(), + ); + if let Some(cnonc) = std::env::var_os("XH_TEST_DIGEST_AUTH_CNONCE") { + context.set_custom_cnonce(cnonc.to_string_lossy().to_string()); } - _ => None, - }) - .last(); + let mut prompt = digest_auth::parse(wwwauth.to_str()?)?; + let answer = prompt.respond(&context)?.to_header_string(); + request + .headers_mut() + .insert(AUTHORIZATION, HeaderValue::from_str(&answer)?); + if let Some(url) = std::env::var_os("XH_TEST_DIGEST_AUTH_URL") { + *request.url_mut() = reqwest::Url::parse(&url.to_string_lossy())?; + } + self.print(&mut ctx, response, &mut request)?; + Ok(self.next(&mut ctx, request)?) + } + _ => Ok(response), + } } - - None } #[cfg(test)] @@ -95,7 +153,7 @@ mod tests { (":", ("", None)), ]; for (input, output) in expected { - let (user, pass) = parse_auth(input.to_string(), "").unwrap(); + let (user, pass) = parse_auth(input, "").unwrap(); assert_eq!(output, (user.as_str(), pass.as_deref())); } } @@ -111,24 +169,24 @@ mod tests { ( "example.com", good_netrc, - Some(("user".to_string(), Some("pass".to_string()))), + Some(Auth::Basic("user".to_string(), Some("pass".to_string()))), ), ("example.org", good_netrc, None), ("example.com", malformed_netrc, None), ( "example.com", missing_login, - Some(("".to_string(), Some("pass".to_string()))), + Some(Auth::Basic("".to_string(), Some("pass".to_string()))), ), ( "example.com", missing_pass, - Some(("user".to_string(), None)), + Some(Auth::Basic("user".to_string(), None)), ), ]; for (machine, netrc, output) in expected { - assert_eq!(output, auth_from_netrc(machine, netrc)); + assert_eq!(output, Auth::from_netrc(netrc, AuthType::basic, machine)); } } } diff --git a/src/cli.rs b/src/cli.rs index f7f206cc..cbcaa67f 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -152,21 +152,22 @@ pub struct Cli { #[structopt(skip)] pub is_session_read_only: bool, - // Currently deprecated in favor of --bearer, un-hide if new auth types are introduced /// Specify the auth mechanism. - #[structopt(short = "A", long, possible_values = &AuthType::variants(), - default_value = "basic", case_insensitive = true, hidden = true)] - pub auth_type: AuthType, + #[structopt(short = "A", long, possible_values = &AuthType::variants(), case_insensitive = true)] + pub auth_type: Option, - /// Authenticate as USER with PASS. PASS will be prompted if missing. + /// Authenticate as USER with PASS or with token. /// - /// Use a trailing colon (i.e. `USER:`) to authenticate with just a username. + /// PASS will be prompted if missing. Use a trailing colon (i.e. `USER:`) + /// to authenticate with just a username. + /// + /// if --auth-type=bearer then --auth expects a token /// {n}{n}{n} - #[structopt(short = "a", long, value_name = "USER[:PASS]")] + #[structopt(short = "a", long, value_name = "USER[:PASS] | token")] pub auth: Option, /// Authenticate with a bearer token. - #[structopt(long, value_name = "TOKEN")] + #[structopt(long, value_name = "TOKEN", hidden = true)] pub bearer: Option, /// Do not use credentials from .netrc @@ -520,8 +521,9 @@ impl Cli { if self.https { self.default_scheme = Some("https".to_string()); } - if self.auth_type == AuthType::bearer && self.auth.is_some() { - self.bearer = self.auth.take(); + if self.bearer.is_some() { + self.auth_type = Some(AuthType::bearer); + self.auth = self.bearer.take(); } self.check_status = match (self.check_status_raw, matches.is_present("no-check-status")) { (true, true) => unreachable!(), @@ -705,7 +707,13 @@ arg_enum! { #[allow(non_camel_case_types)] #[derive(Debug, PartialEq)] pub enum AuthType { - basic, bearer + basic, bearer, digest + } +} + +impl Default for AuthType { + fn default() -> Self { + AuthType::basic } } @@ -746,7 +754,7 @@ impl Theme { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct Print { pub request_headers: bool, pub request_body: bool, @@ -1093,29 +1101,6 @@ mod tests { ); } - #[test] - fn auth() { - let cli = parse(&["--auth=user:pass", ":"]).unwrap(); - assert_eq!(cli.auth.as_deref(), Some("user:pass")); - assert_eq!(cli.bearer, None); - - let cli = parse(&["--auth=user:pass", "--auth-type=basic", ":"]).unwrap(); - assert_eq!(cli.auth.as_deref(), Some("user:pass")); - assert_eq!(cli.bearer, None); - - let cli = parse(&["--auth=token", "--auth-type=bearer", ":"]).unwrap(); - assert_eq!(cli.auth, None); - assert_eq!(cli.bearer.as_deref(), Some("token")); - - let cli = parse(&["--bearer=token", "--auth-type=bearer", ":"]).unwrap(); - assert_eq!(cli.auth, None); - assert_eq!(cli.bearer.as_deref(), Some("token")); - - let cli = parse(&["--auth-type=bearer", ":"]).unwrap(); - assert_eq!(cli.auth, None); - assert_eq!(cli.bearer, None); - } - #[test] fn request_type_overrides() { let cli = parse(&["--form", "--json", ":"]).unwrap(); @@ -1313,7 +1298,7 @@ mod tests { ]) .unwrap(); assert_eq!(cli.bearer, None); - assert_eq!(cli.auth_type, AuthType::basic); + assert_eq!(cli.auth_type, None); } #[test] diff --git a/src/main.rs b/src/main.rs index a41f96dc..b0f8117b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ mod buffer; mod cli; mod download; mod formatting; +mod middleware; mod printer; mod redirect; mod request_items; @@ -21,16 +22,17 @@ use std::sync::Arc; use anyhow::{anyhow, Context, Result}; use atty::Stream; +use redirect::RedirectFollower; use reqwest::blocking::Client; use reqwest::header::{ - HeaderValue, ACCEPT, ACCEPT_ENCODING, AUTHORIZATION, CONNECTION, CONTENT_TYPE, COOKIE, RANGE, - USER_AGENT, + HeaderValue, ACCEPT, ACCEPT_ENCODING, CONNECTION, CONTENT_TYPE, COOKIE, RANGE, USER_AGENT, }; -use crate::auth::{auth_from_netrc, parse_auth, read_netrc}; +use crate::auth::{read_netrc, Auth, DigestAuthMiddleware}; use crate::buffer::Buffer; use crate::cli::{BodyType, Cli, HttpVersion, Print, Proxy, Verify}; use crate::download::{download_file, get_file_size}; +use crate::middleware::ClientWithMiddleware; use crate::printer::Printer; use crate::request_items::{Body, FORM_CONTENT_TYPE, JSON_ACCEPT, JSON_CONTENT_TYPE}; use crate::session::Session; @@ -137,6 +139,8 @@ fn run(args: Cli) -> Result { let mut exit_code: i32 = 0; let mut resume: Option = None; + let mut auth = None; + let mut save_auth_in_session = true; if args.url.scheme() == "https" { let verify = args.verify.unwrap_or_else(|| { @@ -243,14 +247,10 @@ fn run(args: Cli) -> Result { }; if let Some(ref mut s) = session { + auth = s.auth()?; for (key, value) in s.headers()?.iter() { headers.entry(key).or_insert_with(|| value.clone()); } - if let Some(auth) = s.auth()? { - headers - .entry(AUTHORIZATION) - .or_insert(HeaderValue::from_str(&auth)?); - } s.save_headers(&headers)?; let mut cookie_jar = cookie_jar.lock().unwrap(); @@ -340,26 +340,33 @@ fn run(args: Cli) -> Result { } } - if let Some(auth) = args.auth { - let (username, password) = parse_auth(auth, args.url.host_str().unwrap_or(""))?; - if let Some(ref mut s) = session { - s.save_basic_auth(username.clone(), password.clone()); - } - request_builder = request_builder.basic_auth(username, password); + let auth_type = args.auth_type.unwrap_or_default(); + if let Some(auth_from_arg) = args.auth { + auth = Some(Auth::from_str( + &auth_from_arg, + auth_type, + args.url.host_str().unwrap_or(""), + )?); } else if !args.ignore_netrc { - if let Some(host) = args.url.host_str() { - if let Some(netrc) = read_netrc() { - if let Some((username, password)) = auth_from_netrc(host, &netrc) { - request_builder = request_builder.basic_auth(username, password); - } - } + if let (Some(host), Some(netrc)) = (args.url.host_str(), read_netrc()) { + auth = Auth::from_netrc(&netrc, auth_type, host); + save_auth_in_session = false; } } - if let Some(token) = args.bearer { + + if let Some(auth) = &auth { if let Some(ref mut s) = session { - s.save_bearer_auth(token.clone()) + if save_auth_in_session { + s.save_auth(auth); + } + } + request_builder = match auth { + Auth::Basic(username, password) => { + request_builder.basic_auth(username, password.as_ref()) + } + Auth::Bearer(token) => request_builder.bearer_auth(token), + Auth::Digest(..) => request_builder, } - request_builder = request_builder.bearer_auth(token); } let mut request = request_builder.headers(headers).build()?; @@ -396,33 +403,50 @@ fn run(args: Cli) -> Result { ), }; let pretty = args.pretty.unwrap_or_else(|| buffer.guess_pretty()); - let mut printer = Printer::new(print.clone(), pretty, args.style, args.stream, buffer); + let mut printer = Printer::new(pretty, args.style, args.stream, buffer); let response_charset = args.response_charset; let response_mime = args.response_mime.as_deref(); - printer.print_request_headers(&request, &*cookie_jar)?; - printer.print_request_body(&mut request)?; + if print.request_headers { + printer.print_request_headers(&request, &*cookie_jar)?; + } + if print.request_body { + printer.print_request_body(&mut request)?; + } if !args.offline { - let response = if args.follow { - let mut client = - redirect::RedirectFollower::new(&client, args.max_redirects.unwrap_or(10)); - if let Some(history_print) = args.history_print { - printer.print = history_print; - } + let response = { + let history_print = args.history_print.unwrap_or(print); + let mut client = ClientWithMiddleware::new(&client); if args.all { - client.on_redirect(|prev_response, next_request| { - printer.print_response_headers(&prev_response)?; - printer.print_response_body(prev_response, response_charset, response_mime)?; - printer.print_separator()?; - printer.print_request_headers(next_request, &*cookie_jar)?; - printer.print_request_body(next_request)?; + client = client.with_printer(|prev_response, next_request| { + if history_print.response_headers { + printer.print_response_headers(&prev_response)?; + } + if history_print.response_body { + printer.print_response_body( + prev_response, + response_charset, + response_mime, + )?; + printer.print_separator()?; + } + if history_print.request_headers { + printer.print_request_headers(next_request, &*cookie_jar)?; + } + if history_print.request_body { + printer.print_request_body(next_request)?; + } Ok(()) }); } - client.execute(request)? - } else { + if args.follow { + client = client.with(RedirectFollower::new(args.max_redirects.unwrap_or(10))); + } + if let Some(Auth::Digest(username, password)) = &auth { + client = client.with(DigestAuthMiddleware::new(username, password)); + } client.execute(request)? }; @@ -439,8 +463,9 @@ fn run(args: Cli) -> Result { warn(&format!("HTTP {}", status)); } - printer.print = print; - printer.print_response_headers(&response)?; + if print.response_headers { + printer.print_response_headers(&response)?; + } if args.download { if exit_code == 0 { download_file( diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 00000000..14b647c6 --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,89 @@ +use anyhow::Result; +use reqwest::blocking::{Client, Request, Response}; + +pub struct Context<'a, 'b> { + client: &'a Client, + printer: Option<&'a mut (dyn FnMut(Response, &mut Request) -> Result<()> + 'b)>, + middlewares: &'a mut [Box], +} + +impl<'a, 'b> Context<'a, 'b> { + fn new( + client: &'a Client, + printer: Option<&'a mut (dyn FnMut(Response, &mut Request) -> Result<()> + 'b)>, + middlewares: &'a mut [Box], + ) -> Self { + Context { + client, + printer, + middlewares, + } + } + + fn execute(&mut self, request: Request) -> Result { + match self.middlewares { + [] => Ok(self.client.execute(request)?), + [ref mut head, tail @ ..] => head.handle( + Context::new(self.client, self.printer.as_deref_mut(), tail), + request, + ), + } + } +} + +pub trait Middleware { + fn handle(&mut self, ctx: Context, request: Request) -> Result; + + fn next(&self, ctx: &mut Context, request: Request) -> Result { + ctx.execute(request) + } + + fn print(&self, ctx: &mut Context, response: Response, request: &mut Request) -> Result<()> { + if let Some(ref mut printer) = ctx.printer { + printer(response, request)? + } + + Ok(()) + } +} + +pub struct ClientWithMiddleware<'a, T> +where + T: FnMut(Response, &mut Request) -> Result<()>, +{ + client: &'a Client, + printer: Option, + middlewares: Vec>, +} + +impl<'a, T> ClientWithMiddleware<'a, T> +where + T: FnMut(Response, &mut Request) -> Result<()> + 'a, +{ + pub fn new(client: &'a Client) -> Self { + ClientWithMiddleware { + client, + printer: None, + middlewares: vec![], + } + } + + pub fn with_printer(mut self, printer: T) -> Self { + self.printer = Some(printer); + self + } + + pub fn with(mut self, middleware: impl Middleware + 'a) -> Self { + self.middlewares.push(Box::new(middleware)); + self + } + + pub fn execute(&mut self, request: Request) -> Result { + let mut ctx = Context::new( + self.client, + self.printer.as_mut().map(|p| p as _), + &mut self.middlewares[..], + ); + ctx.execute(request) + } +} diff --git a/src/printer.rs b/src/printer.rs index 2f0c0b56..3bf6aba5 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -13,7 +13,7 @@ use termcolor::WriteColor; use crate::{ buffer::Buffer, - cli::{Pretty, Print, Theme}, + cli::{Pretty, Theme}, formatting::{get_json_formatter, Highlighter}, utils::{copy_largebuf, test_mode, BUFFER_SIZE}, }; @@ -66,7 +66,6 @@ impl<'a, T: Read> BinaryGuard<'a, T> { } pub struct Printer { - pub print: Print, indent_json: bool, color: bool, theme: Theme, @@ -76,17 +75,10 @@ pub struct Printer { } impl Printer { - pub fn new( - print: Print, - pretty: Pretty, - theme: Option, - stream: bool, - buffer: Buffer, - ) -> Self { + pub fn new(pretty: Pretty, theme: Option, stream: bool, buffer: Buffer) -> Self { let theme = theme.unwrap_or(Theme::auto); Printer { - print, indent_json: pretty.format(), sort_headers: pretty.format(), color: pretty.color() && (cfg!(test) || buffer.supports_color()), @@ -291,13 +283,8 @@ impl Printer { header_string } - // Each of the print_* functions adds an extra line separator at the end - // except for print_response_body. We are using this function when we have - // something to print after the response body. pub fn print_separator(&mut self) -> io::Result<()> { - if self.print.response_body { - self.buffer.print("\n")?; - } + self.buffer.print("\n")?; Ok(()) } @@ -305,10 +292,6 @@ impl Printer { where T: CookieStore, { - if !self.print.request_headers { - return Ok(()); - } - let method = request.method(); let url = request.url(); let query_string = url.query().map_or(String::from(""), |q| ["?", q].concat()); @@ -357,10 +340,6 @@ impl Printer { } pub fn print_response_headers(&mut self, response: &Response) -> io::Result<()> { - if !self.print.response_headers { - return Ok(()); - } - let version = response.version(); let status = response.status(); let headers = response.headers(); @@ -374,10 +353,6 @@ impl Printer { } pub fn print_request_body(&mut self, request: &mut Request) -> anyhow::Result<()> { - if !self.print.request_body { - return Ok(()); - } - let content_type = get_content_type(request.headers()); if let Some(body) = request.body_mut() { let body = body.buffer()?; @@ -399,10 +374,6 @@ impl Printer { encoding: Option<&'static Encoding>, mime: Option<&str>, ) -> anyhow::Result<()> { - if !self.print.response_body { - return Ok(()); - } - let content_type = mime .map(ContentType::from) .unwrap_or_else(|| get_content_type(response.headers())); @@ -569,7 +540,7 @@ mod tests { let buffer = Buffer::new(args.download, args.output.as_deref(), is_stdout_tty, None).unwrap(); let pretty = args.pretty.unwrap_or_else(|| buffer.guess_pretty()); - Printer::new("hHbB".parse().unwrap(), pretty, args.style, false, buffer) + Printer::new(pretty, args.style, false, buffer) } fn temp_path() -> String { @@ -651,7 +622,6 @@ mod tests { #[test] fn test_header_casing() { let p = Printer { - print: "hHbB".parse().unwrap(), indent_json: false, color: false, theme: Theme::auto, diff --git a/src/redirect.rs b/src/redirect.rs index c17d442a..70c8dc98 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -1,5 +1,4 @@ use anyhow::{anyhow, Result}; -use reqwest::blocking::Client; use reqwest::blocking::{Request, Response}; use reqwest::header::{ HeaderMap, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, LOCATION, @@ -7,36 +6,25 @@ use reqwest::header::{ }; use reqwest::{Method, StatusCode, Url}; -pub struct RedirectFollower<'a, T> -where - T: FnMut(Response, &mut Request) -> Result<()>, -{ - client: &'a Client, +use crate::middleware::{Context, Middleware}; +use crate::utils::clone_request; + +pub struct RedirectFollower { max_redirects: usize, - callback: Option, } -impl<'a, T> RedirectFollower<'a, T> -where - T: FnMut(Response, &mut Request) -> Result<()>, -{ - pub fn new(client: &'a Client, max_redirects: usize) -> Self { - RedirectFollower { - client, - max_redirects, - callback: None, - } - } - - pub fn on_redirect(&mut self, callback: T) { - self.callback = Some(callback); +impl RedirectFollower { + pub fn new(max_redirects: usize) -> Self { + RedirectFollower { max_redirects } } +} - pub fn execute(&mut self, mut first_request: Request) -> Result { +impl Middleware for RedirectFollower { + fn handle(&mut self, mut ctx: Context, mut first_request: Request) -> Result { // This buffers the body in case we need it again later // reqwest does *not* do this, it ignores 307/308 with a streaming body let mut request = clone_request(&mut first_request)?; - let mut response = self.client.execute(first_request)?; + let mut response = self.next(&mut ctx, first_request)?; let mut remaining_redirects = self.max_redirects - 1; while let Some(mut next_request) = get_next_request(request, &response) { @@ -48,26 +36,15 @@ where self.max_redirects )); } - if let Some(ref mut callback) = self.callback { - callback(response, &mut next_request)?; - } + self.print(&mut ctx, response, &mut next_request)?; request = clone_request(&mut next_request)?; - response = self.client.execute(next_request)?; + response = self.next(&mut ctx, next_request)?; } Ok(response) } } -fn clone_request(request: &mut Request) -> Result { - if let Some(b) = request.body_mut().as_mut() { - b.buffer()?; - } - // This doesn't copy the contents of the buffer, cloning requests is cheap - // https://docs.rs/bytes/1.0.1/bytes/struct.Bytes.html - Ok(request.try_clone().unwrap()) // guaranteed to not fail if body is already buffered -} - // See https://github.com/seanmonstar/reqwest/blob/bbeb1ede4e8098481c3de6f2cafb8ecca1db4ede/src/async_impl/client.rs#L1500-L1607 fn get_next_request(mut request: Request, response: &Response) -> Option { let get_next_url = |request: &Request| { diff --git a/src/session.rs b/src/session.rs index 7dfe7e77..fb4e7a10 100644 --- a/src/session.rs +++ b/src/session.rs @@ -10,6 +10,7 @@ use reqwest::header::HeaderMap; use reqwest::Url; use serde::{Deserialize, Serialize}; +use crate::auth; use crate::utils::{config_dir, test_mode}; #[derive(Debug, Serialize, Deserialize)] @@ -115,36 +116,53 @@ impl Session { Ok(()) } - pub fn auth(&self) -> Result> { + pub fn auth(&self) -> Result> { if let Auth { - auth_type: Some(ref auth_type), - raw_auth: Some(ref raw_auth), - } = self.content.auth + auth_type: Some(auth_type), + raw_auth: Some(raw_auth), + } = &self.content.auth { - if auth_type.as_str() == "basic" { - return Ok(Some(format!("Basic {}", base64::encode(raw_auth)))); - } else if auth_type.as_str() == "bearer" { - return Ok(Some(format!("Bearer {}", raw_auth))); - } else { - return Err(anyhow!("Unknown auth type {}", raw_auth)); + match auth_type.as_str() { + "basic" => { + let (username, password) = auth::parse_auth(raw_auth, "")?; + Ok(Some(auth::Auth::Basic(username, password))) + } + "digest" => { + let (username, password) = auth::parse_auth(raw_auth, "")?; + Ok(Some(auth::Auth::Digest( + username, + password.unwrap_or_else(|| "".into()), + ))) + } + "bearer" => Ok(Some(auth::Auth::Bearer(raw_auth.into()))), + _ => Err(anyhow!("Unknown auth type {}", raw_auth)), } - } - - Ok(None) - } - - pub fn save_bearer_auth(&mut self, token: String) { - self.content.auth = Auth { - auth_type: Some("bearer".into()), - raw_auth: Some(token), + } else { + Ok(None) } } - pub fn save_basic_auth(&mut self, username: String, password: Option) { - let password = password.unwrap_or_else(|| "".into()); - self.content.auth = Auth { - auth_type: Some("basic".into()), - raw_auth: Some(format!("{}:{}", username, password)), + pub fn save_auth(&mut self, auth: &auth::Auth) { + match auth { + auth::Auth::Basic(username, password) => { + let password = password.as_deref().unwrap_or(""); + self.content.auth = Auth { + auth_type: Some("basic".into()), + raw_auth: Some(format!("{}:{}", username, password)), + } + } + auth::Auth::Digest(username, password) => { + self.content.auth = Auth { + auth_type: Some("digest".into()), + raw_auth: Some(format!("{}:{}", username, password)), + } + } + auth::Auth::Bearer(token) => { + self.content.auth = Auth { + auth_type: Some("bearer".into()), + raw_auth: Some(token.into()), + } + } } } diff --git a/src/to_curl.rs b/src/to_curl.rs index d8485fda..c3261b75 100644 --- a/src/to_curl.rs +++ b/src/to_curl.rs @@ -3,10 +3,8 @@ use std::io::{stderr, stdout, Write}; use anyhow::{anyhow, Result}; use reqwest::Method; -use crate::{ - cli::{Cli, HttpVersion, Verify}, - request_items::{Body, RequestItem, FORM_CONTENT_TYPE, JSON_ACCEPT, JSON_CONTENT_TYPE}, -}; +use crate::cli::{AuthType, Cli, HttpVersion, Verify}; +use crate::request_items::{Body, RequestItem, FORM_CONTENT_TYPE, JSON_ACCEPT, JSON_CONTENT_TYPE}; pub fn print_curl_translation(args: Cli) -> Result<()> { let cmd = translate(args)?; @@ -243,13 +241,24 @@ pub fn translate(args: Cli) -> Result { cmd.push(format!("{}:", header)); } if let Some(auth) = args.auth { - // curl implements this flag the same way, including password prompt - cmd.flag("-u", "--user"); - cmd.push(auth); - } - if let Some(token) = args.bearer { - cmd.push("--oauth2-bearer"); - cmd.push(token); + match args.auth_type.unwrap_or_default() { + AuthType::basic => { + cmd.push("--basic"); + // curl implements this flag the same way, including password prompt + cmd.flag("-u", "--user"); + cmd.push(auth); + } + AuthType::digest => { + cmd.push("--digest"); + // curl implements this flag the same way, including password prompt + cmd.flag("-u", "--user"); + cmd.push(auth); + } + AuthType::bearer => { + cmd.push("--oauth2-bearer"); + cmd.push(auth); + } + } } if args.request_items.is_multipart() { diff --git a/src/utils.rs b/src/utils.rs index 77b4e9fc..6ee3b56f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,11 +1,20 @@ -use std::{ - env::var_os, - io::{self, Write}, - path::PathBuf, -}; +use std::env::var_os; +use std::io::{self, Write}; +use std::path::PathBuf; +use anyhow::Result; +use reqwest::blocking::Request; use url::{Host, Url}; +pub fn clone_request(request: &mut Request) -> Result { + if let Some(b) = request.body_mut().as_mut() { + b.buffer()?; + } + // This doesn't copy the contents of the buffer, cloning requests is cheap + // https://docs.rs/bytes/1.0.1/bytes/struct.Bytes.html + Ok(request.try_clone().unwrap()) // guaranteed to not fail if body is already buffered +} + /// Whether to make some things more deterministic for the benefit of tests pub fn test_mode() -> bool { // In integration tests the binary isn't compiled with cfg(test), so we @@ -41,6 +50,15 @@ pub fn config_dir() -> Option { } } +pub fn get_home_dir() -> Option { + #[cfg(target_os = "windows")] + if let Some(path) = std::env::var_os("XH_TEST_MODE_WIN_HOME_DIR") { + return Some(PathBuf::from(path)); + } + + dirs::home_dir() +} + // https://stackoverflow.com/a/45145246/5915221 #[macro_export] macro_rules! vec_of_strings { diff --git a/tests/cli.rs b/tests/cli.rs index c0323000..fcbdd28b 100644 --- a/tests/cli.rs +++ b/tests/cli.rs @@ -1908,6 +1908,53 @@ fn bearer_auth_from_session_is_used() { mock.assert(); } +#[test] +fn auth_netrc_is_not_persisted_in_session() { + let server = MockServer::start(); + let mock = server.mock(|when, _| { + when.header("Authorization", "Basic dXNlcjpwYXNz"); + }); + + let mut path_to_session = std::env::temp_dir(); + let file_name = random_string(); + path_to_session.push(file_name); + assert_eq!(path_to_session.exists(), false); + + let mut netrc = tempfile::NamedTempFile::new().unwrap(); + writeln!( + netrc, + "machine {}\nlogin user\npassword pass", + server.host() + ) + .unwrap(); + + get_command() + .env("NETRC", netrc.path()) + .arg(server.base_url()) + .arg("hello:world") + .arg(format!("--session={}", path_to_session.to_string_lossy())) + .assert() + .success(); + + mock.assert(); + + let session_content = read_to_string(path_to_session).unwrap(); + assert_eq!( + serde_json::from_str::(&session_content).unwrap(), + serde_json::json!({ + "__meta__": { + "about": "xh session file", + "xh": "0.0.0" + }, + "auth": { "type": null, "raw_auth": null }, + "cookies": {}, + "headers": { + "hello": "world" + } + }) + ); +} + #[test] fn print_intermediate_requests_and_responses() { let server1 = MockServer::start(); @@ -2197,6 +2244,142 @@ fn warns_if_config_is_invalid() { .success(); } +#[test] +fn digest_auth() { + let server1 = MockServer::start(); + let server2 = MockServer::start(); + let mock1 = server1.mock(|when, then| { + when.matches(|req: &HttpMockRequest| { + !req.headers + .as_ref() + .unwrap() + .iter() + .any(|(key, _)| key == "Authorization") + }); + then.status(401).header("WWW-Authenticate", r#"Digest realm="me@xh.com", nonce="e5051361f053723a807674177fc7022f", qop="auth, auth-int", opaque="9dcf562038f1ec1c8d02f218ef0e7a4b", algorithm=MD5, stale=FALSE"#); + }); + let mock2 = server2.mock(|when, then| { + when.header_exists("Authorization"); + then.body("authenticated"); + }); + + get_command() + .env("XH_TEST_DIGEST_AUTH_URL", server2.base_url()) + .arg("--auth-type=digest") + .arg("--auth=ahmed:12345") + .arg(server1.base_url()) + .assert() + .stdout(contains("HTTP/1.1 200 OK")); + + mock1.assert(); + mock2.assert(); +} + +#[test] +fn successful_digest_auth() { + get_command() + .arg("--auth-type=digest") + .arg("--auth=ahmed:12345") + .arg("httpbin.org/digest-auth/5/ahmed/12345") + .assert() + .stdout(contains("HTTP/1.1 200 OK")); +} + +#[test] +fn unsuccessful_digest_auth() { + get_command() + .arg("--auth-type=digest") + .arg("--auth=ahmed:wrongpass") + .arg("httpbin.org/digest-auth/5/ahmed/12345") + .assert() + .stdout(contains("HTTP/1.1 401 Unauthorized")); +} + +#[test] +fn digest_auth_with_redirection() { + let server1 = MockServer::start(); + let server2 = MockServer::start(); + let server3 = MockServer::start(); + let mock1 = server1.mock(|when, then| { + when.matches(|req: &HttpMockRequest| { + !req.headers + .as_ref() + .unwrap() + .iter() + .any(|(key, _)| key == "Authorization") + }); + then.status(401) + .header("WWW-Authenticate", r#"Digest realm="me@xh.com", nonce="e5051361f053723a807674177fc7022f", qop="auth, auth-int", opaque="9dcf562038f1ec1c8d02f218ef0e7a4b", algorithm=MD5, stale=FALSE"#) + .header("date", "N/A"); + }); + let mock2 = server2.mock(|when, then| { + when.header_exists("Authorization"); + then.status(302) + .header("location", &server3.base_url()) + .header("date", "N/A") + .body("authentication successful, redirecting..."); + }); + server3.mock(|_, then| { + then.header("date", "N/A").body("final destination"); + }); + + get_command() + .env("XH_TEST_DIGEST_AUTH_URL", server2.base_url()) + .env("XH_TEST_DIGEST_AUTH_CNONCE", "f2/wE4q74E6zIJEtWaHKaf5wv/H5QzzpXusqGemxURZJ") + .arg("--auth-type=digest") + .arg("--auth=ahmed:12345") + .arg("--follow") + .arg("--verbose") + .arg(server1.base_url()) + .assert() + .stdout(formatdoc! {r#" + GET / HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br + Connection: keep-alive + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 401 Unauthorized + Content-Length: 0 + Date: N/A + Www-Authenticate: Digest realm="me@xh.com", nonce="e5051361f053723a807674177fc7022f", qop="auth, auth-int", opaque="9dcf562038f1ec1c8d02f218ef0e7a4b", algorithm=MD5, stale=FALSE + + + + GET / HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br + Authorization: Digest username="ahmed", realm="me@xh.com", nonce="e5051361f053723a807674177fc7022f", uri="/", qop=auth, nc=00000001, cnonce="f2/wE4q74E6zIJEtWaHKaf5wv/H5QzzpXusqGemxURZJ", response="1e96c9808de24d5dd36e9e4865ffca7d", opaque="9dcf562038f1ec1c8d02f218ef0e7a4b", algorithm=MD5 + Connection: keep-alive + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 302 Found + Content-Length: 41 + Date: N/A + Location: {redirect_url} + + authentication successful, redirecting... + + GET / HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br + Connection: keep-alive + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 200 OK + Content-Length: 17 + Date: N/A + + final destination + "#, redirect_url = server3.base_url()}); + + mock1.assert(); + mock2.assert(); +} + #[test] fn http1_0() { get_command()