diff --git a/src/agent.rs b/src/agent.rs index c39e0e27..6f89fa11 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -8,14 +8,17 @@ //! Since request executions are driven through futures, the agent also acts as //! a specialized task executor for tasks related to requests. -use crate::handler::RequestHandler; -use crate::task::{UdpWaker, WakerExt}; -use crate::Error; +use crate::{ + error::Error, + handler::RequestHandler, + task::{UdpWaker, WakerExt}, +}; use crossbeam_utils::sync::WaitGroup; use curl::multi::WaitFd; use flume::{Receiver, Sender}; use slab::Slab; use std::{ + io, net::UdpSocket, sync::Mutex, task::Waker, @@ -54,7 +57,7 @@ impl AgentBuilder { /// Spawn a new agent using the configuration in this builder and return a /// handle for communicating with the agent. - pub(crate) fn spawn(&self) -> Result { + pub(crate) fn spawn(&self) -> io::Result { let create_start = Instant::now(); // Initialize libcurl, if necessary, on the current thread. diff --git a/src/body.rs b/src/body.rs index ceed168e..18e7aa15 100644 --- a/src/body.rs +++ b/src/body.rs @@ -10,28 +10,6 @@ use std::{ task::{Context, Poll}, }; -macro_rules! match_type { - { - $( - <$name:ident as $T:ty> => $branch:expr, - )* - $defaultName:ident => $defaultBranch:expr, - } => {{ - match () { - $( - _ if ::std::any::Any::type_id(&$name) == ::std::any::TypeId::of::<$T>() => { - #[allow(unsafe_code)] - let $name: $T = unsafe { - ::std::mem::transmute_copy::<_, $T>(&::std::mem::ManuallyDrop::new($name)) - }; - $branch - } - )* - _ => $defaultBranch, - } - }}; -} - /// Contains the body of an HTTP request or response. /// /// This type is used to encapsulate the underlying stream or region of memory diff --git a/src/client.rs b/src/client.rs index a1c61091..312685e9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,13 +3,14 @@ use crate::{ agent::{self, AgentBuilder}, auth::{Authentication, Credentials}, + body::Body, config::internal::{ConfigurableBase, SetOpt}, config::*, default_headers::DefaultHeadersInterceptor, + error::{Error, ErrorKind}, handler::{RequestHandler, ResponseBodyReader}, headers, interceptor::{self, Interceptor, InterceptorObj}, - Body, Error, }; use futures_lite::{future::block_on, io::AsyncRead, pin}; use http::{ @@ -343,11 +344,11 @@ impl HttpClientBuilder { self.default_headers.append(key, value); } Err(e) => { - self.error = Some(e.into().into()); + self.error = Some(Error::new(ErrorKind::ClientInitialization, e.into())); } }, Err(e) => { - self.error = Some(e.into().into()); + self.error = Some(Error::new(ErrorKind::ClientInitialization, e.into())); } } self @@ -450,20 +451,22 @@ impl HttpClientBuilder { #[cfg(not(feature = "cookies"))] let inner = Inner { - agent: self.agent_builder.spawn()?, + agent: self.agent_builder.spawn().map_err(|e| Error::new(ErrorKind::ClientInitialization, e))?, defaults: self.defaults, interceptors: self.interceptors, }; #[cfg(feature = "cookies")] let inner = Inner { - agent: self.agent_builder.spawn()?, + agent: self.agent_builder.spawn().map_err(|e| Error::new(ErrorKind::ClientInitialization, e))?, defaults: self.defaults, interceptors: self.interceptors, cookie_jar: self.cookie_jar, }; - Ok(HttpClient { inner: Arc::new(inner) }) + Ok(HttpClient { + inner: Arc::new(inner), + }) } } @@ -896,7 +899,10 @@ impl HttpClient { builder: http::request::Builder, body: Body, ) -> ResponseFuture<'_> { - ResponseFuture::new(async move { self.send_async_inner(builder.body(body)?).await }) + ResponseFuture::new(async move { + self.send_async_inner(builder.body(body).map_err(Error::from_any)?) + .await + }) } /// Actually send the request. All the public methods go through here. diff --git a/src/config/mod.rs b/src/config/mod.rs index ed2a2349..a2cb17e6 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -170,7 +170,8 @@ pub trait Configurable: internal::ConfigurableBase { /// decode the HTTP response body for known and available compression /// algorithms. If the server returns a response with an unknown or /// unavailable encoding, Isahc will return an - /// [`InvalidContentEncoding`](crate::Error::InvalidContentEncoding) error. + /// [`InvalidContentEncoding`](crate::error::ErrorKind::InvalidContentEncoding) + /// error. /// /// If you do not specify a specific value for the /// [`Accept-Encoding`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding) diff --git a/src/error.rs b/src/error.rs index b6580f24..07cc8f71 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,187 +1,404 @@ //! Types for error handling. -#![allow(deprecated)] +use std::{error::Error as StdError, fmt, io, sync::Arc}; -use std::error::Error as StdError; -use std::fmt; -use std::io; - -/// All possible types of errors that can be returned from Isahc. -#[derive(Debug)] -pub enum Error { - /// The request was aborted before it could be completed. - Aborted, +/// A non-exhaustive list of error types that can occur while sending an HTTP +/// request or receiving an HTTP response. +/// +/// These are meant to be treated as general error codes that allow you to +/// handle different sorts of errors in different ways, but are not always +/// specific. The list is also non-exhaustive, and more variants may be added in +/// the future. +#[derive(Clone, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum ErrorKind { /// A problem occurred with the local certificate. - BadClientCertificate(Option), + BadClientCertificate, + /// The server certificate could not be validated. - BadServerCertificate(Option), + BadServerCertificate, + + /// The HTTP client failed to initialize. + ClientInitialization, + /// Failed to connect to the server. - ConnectFailed, - /// Couldn't resolve host name. - CouldntResolveHost, - /// Couldn't resolve proxy host name. - CouldntResolveProxy, - /// An unrecognized error thrown by curl. - Curl(String), - /// Unrecognized or bad content encoding returned by the server. - InvalidContentEncoding(Option), - /// Provided credentials were rejected by the server. + ConnectionFailed, + + /// The server either returned a response using an unknown or unsupported + /// encoding format, or the response encoding was malformed. + InvalidContentEncoding, + + /// Provided authentication credentials were rejected by the server. + /// + /// This error is only returned when using Isahc's built-in authentication + /// methods. If using authentication headers manually, the server's response + /// will be returned as a success unaltered. InvalidCredentials, - /// Validation error when constructing the request or parsing the response. - InvalidHttpFormat(http::Error), - /// Invalid UTF-8 string error. - InvalidUtf8, - /// An unknown I/O error. - Io(io::Error), - /// The server did not send a response. - NoResponse, - /// The server does not support or accept range requests. - RangeRequestUnsupported, - /// An error occurred while writing the request body. - RequestBodyError(Option), - /// An error occurred while reading the response body. - ResponseBodyError(Option), - /// Failed to connect over a secure socket. - SSLConnectFailed(Option), - /// An error ocurred in the secure socket engine. - SSLEngineError(Option), - /// An ongoing request took longer than the configured timeout time. + + /// The request to be sent was invalid and could not be sent. + /// + /// Note that this is only returned for requests that the client deemed + /// invalid. If the request appears to be valid but is rejected by the + /// server, then the server's response will likely indicate as such. + InvalidRequest, + + /// An I/O error either sending the request or reading the response. This + /// could be caused by a problem on the client machine, a problem on the + /// server machine, or a problem with the network between the two. + Io, + + /// Failed to resolve a host name. + /// + /// This could be caused by any number of problems, including failure to + /// reach a DNS server, misconfigured resolver configuration, or the + /// hostname simply does not exist. + NameResolution, + + /// The server made an unrecoverable HTTP protocol violation. This indicates + /// a bug in the server. Retrying a request that returns this error is + /// likely to produce the same error. + ProtocolViolation, + + /// Request processing could not continue because the client needed to + /// re-send the request body, but was unable to rewind the body stream to + /// the beginning in order to do so. + RequestBodyNotRewindable, + + /// A request or operation took longer than the configured timeout time. Timeout, + + /// An error ocurred in the secure socket engine. + TlsEngine, + /// Number of redirects hit the maximum amount. TooManyRedirects, + + /// An unknown error occurred. This likely indicates a problem in the HTTP + /// client or in a dependency, but the client was able to recover instead of + /// panicking. Subsequent requests will likely succeed. + /// + /// Only used internally. + #[doc(hidden)] + Unknown, } -impl fmt::Display for Error { +impl ErrorKind { + #[inline] + fn description(&self) -> Option<&str> { + match self { + Self::BadClientCertificate => Some("a problem occurred with the local certificate"), + Self::BadServerCertificate => Some("the server certificate could not be validated"), + Self::ClientInitialization => Some("failed to initialize client"), + Self::ConnectionFailed => Some("failed to connect to the server"), + Self::InvalidContentEncoding => Some("the server either returned a response using an unknown or unsupported encoding format, or the response encoding was malformed"), + Self::InvalidCredentials => Some("provided authentication credentials were rejected by the server"), + Self::InvalidRequest => Some("invalid HTTP request"), + Self::NameResolution => Some("failed to resolve host name"), + Self::ProtocolViolation => Some("the server made an unrecoverable HTTP protocol violation"), + Self::RequestBodyNotRewindable => Some("request body could not be re-sent because it is not rewindable"), + Self::Timeout => Some("request or operation took longer than the configured timeout time"), + Self::TlsEngine => Some("error ocurred in the secure socket engine"), + Self::TooManyRedirects => Some("number of redirects hit the maximum amount"), + _ => None, + } + } +} + +impl fmt::Display for ErrorKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?}: {}", self, Error::description(self)) + f.write_str(self.description().unwrap_or("unknown error")) } } -impl StdError for Error { - fn description(&self) -> &str { - match self { - Error::Aborted => "request aborted unexpectedly", - Error::BadClientCertificate(Some(ref e)) => e, - Error::BadServerCertificate(Some(ref e)) => e, - Error::ConnectFailed => "failed to connect to the server", - Error::CouldntResolveHost => "couldn't resolve host name", - Error::CouldntResolveProxy => "couldn't resolve proxy host name", - Error::Curl(ref e) => e, - Error::InvalidContentEncoding(Some(ref e)) => e, - Error::InvalidCredentials => "credentials were rejected by the server", - Error::InvalidHttpFormat(ref e) => e.description(), - Error::InvalidUtf8 => "bytes are not valid UTF-8", - Error::Io(ref e) => e.description(), - Error::NoResponse => "server did not send a response", - Error::RangeRequestUnsupported => "server does not support or accept range requests", - Error::RequestBodyError(Some(ref e)) => e, - Error::ResponseBodyError(Some(ref e)) => e, - Error::SSLConnectFailed(Some(ref e)) => e, - Error::SSLEngineError(Some(ref e)) => e, - Error::Timeout => "request took longer than the configured timeout", - Error::TooManyRedirects => "max redirect limit exceeded", - _ => "unknown error", +// Improve equality ergonomics for references. +impl PartialEq for &'_ ErrorKind { + fn eq(&self, other: &ErrorKind) -> bool { + *self == other + } +} + +/// An error encountered while sending an HTTP request or receiving an HTTP +/// response. +/// +/// This type is intentionally opaque, as sending an HTTP request involves many +/// different moving parts, some of which can be platform or device-dependent. +/// It is recommended that you use the [`kind`][Error::kind] method to get a +/// more generalized classification of error types that this error could be if +/// you need to handle different sorts of errors in different ways. +/// +/// If you need to get more specific details about the reason for the error, you +/// can use the [`source`][std::error::Error::source] method. We do not provide +/// any stability guarantees about what error sources are returned. +#[derive(Clone)] +pub struct Error(Arc); + +struct Inner { + kind: ErrorKind, + context: Option, + source: Option>, +} + +impl Error { + /// Create a new error from a given error kind and source error. + pub(crate) fn new(kind: ErrorKind, source: E) -> Self + where + E: StdError + Send + Sync + 'static, + { + Self::with_context(kind, None, source) + } + + /// Create a new error from a given error kind, source error, and context + /// string. + pub(crate) fn with_context(kind: ErrorKind, context: Option, source: E) -> Self + where + E: StdError + Send + Sync + 'static, + { + Self(Arc::new(Inner { + kind, + context, + source: Some(Box::new(source)), + })) + } + + /// Statically cast a given error into an Isahc error, converting if + /// necessary. + pub(crate) fn from_any(error: E) -> Self + where + E: StdError + Send + Sync + 'static, + { + match_type! { + => error, + => error.into(), + error => Error::new(ErrorKind::Unknown, error), } } - fn cause(&self) -> Option<&dyn StdError> { - match self { - Error::InvalidHttpFormat(e) => Some(e), - Error::Io(e) => Some(e), - _ => None, + /// Get the kind of error this represents. + /// + /// The kind returned may not be matchable against any known documented if + /// the reason for the error is unknown. Unknown errors may be an indication + /// of a bug, or an error condition that we do not recognize appropriately. + /// Either way, please report such occurrences to us! + #[inline] + pub fn kind(&self) -> &ErrorKind { + &self.0.kind + } + + /// Returns true if this error was likely caused by the client. + /// + /// Usually indicates that the client was misconfigured or used to send + /// invalid data to the server. Requests that return these sorts of errors + /// probably should not be retried without first fixing the request + /// parameters. + pub fn is_client(&self) -> bool { + match self.kind() { + ErrorKind::BadClientCertificate + | ErrorKind::ClientInitialization + | ErrorKind::InvalidCredentials + | ErrorKind::InvalidRequest + | ErrorKind::RequestBodyNotRewindable + | ErrorKind::TlsEngine => true, + _ => false, } } -} -#[doc(hidden)] -impl From for Error { - fn from(error: curl::Error) -> Error { - if error.is_ssl_certproblem() || error.is_ssl_cacert_badfile() { - Error::BadClientCertificate(error.extra_description().map(str::to_owned)) - } else if error.is_peer_failed_verification() || error.is_ssl_cacert() { - Error::BadServerCertificate(error.extra_description().map(str::to_owned)) - } else if error.is_couldnt_connect() { - Error::ConnectFailed - } else if error.is_couldnt_resolve_host() { - Error::CouldntResolveHost - } else if error.is_couldnt_resolve_proxy() { - Error::CouldntResolveProxy - } else if error.is_bad_content_encoding() || error.is_conv_failed() { - Error::InvalidContentEncoding(error.extra_description().map(str::to_owned)) - } else if error.is_login_denied() { - Error::InvalidCredentials - } else if error.is_got_nothing() { - Error::NoResponse - } else if error.is_range_error() { - Error::RangeRequestUnsupported - } else if error.is_read_error() || error.is_aborted_by_callback() { - Error::RequestBodyError(error.extra_description().map(str::to_owned)) - } else if error.is_write_error() || error.is_partial_file() { - Error::ResponseBodyError(error.extra_description().map(str::to_owned)) - } else if error.is_ssl_connect_error() { - Error::SSLConnectFailed(error.extra_description().map(str::to_owned)) - } else if error.is_ssl_engine_initfailed() - || error.is_ssl_engine_notfound() - || error.is_ssl_engine_setfailed() - { - Error::SSLEngineError(error.extra_description().map(str::to_owned)) - } else if error.is_operation_timedout() { - Error::Timeout - } else if error.is_too_many_redirects() { - Error::TooManyRedirects - } else { - Error::Curl(error.description().to_owned()) + /// Returns true if this is an error likely related to network failures. + pub fn is_network(&self) -> bool { + match self.kind() { + ErrorKind::ConnectionFailed | ErrorKind::Io | ErrorKind::NameResolution => true, + _ => false, + } + } + + /// Returns true if this error was likely the fault of the server. + pub fn is_server(&self) -> bool { + match self.kind() { + ErrorKind::BadServerCertificate | ErrorKind::ProtocolViolation | ErrorKind::TooManyRedirects => { + true + } + _ => false, + } + } + + /// Returns true if this error is related to SSL/TLS. + pub fn is_tls(&self) -> bool { + match self.kind() { + ErrorKind::BadClientCertificate + | ErrorKind::BadServerCertificate + | ErrorKind::TlsEngine => true, + _ => false, } } } -#[doc(hidden)] -impl From for Error { - fn from(error: curl::MultiError) -> Error { - Error::Curl(error.description().to_owned()) +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.0.source.as_ref().map(|source| &**source as _) } } -#[doc(hidden)] -impl From for Error { - fn from(error: http::Error) -> Error { - Error::InvalidHttpFormat(error) +impl PartialEq for Error { + fn eq(&self, kind: &ErrorKind) -> bool { + self.kind().eq(kind) + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Error") + .field("kind", &self.kind()) + .field("context", &self.0.context) + .field("source", &self.source()) + .finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(s) = self.0.context.as_ref() { + write!(f, "{}: {}", self.kind(), s) + } else { + write!(f, "{}", self.kind()) + } + } +} + +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self(Arc::new(Inner { + kind, + context: None, + source: None, + })) } } -#[doc(hidden)] impl From for Error { - fn from(error: io::Error) -> Error { - match error.kind() { - io::ErrorKind::ConnectionRefused => Error::ConnectFailed, - io::ErrorKind::TimedOut => Error::Timeout, - _ => Error::Io(error), + fn from(error: io::Error) -> Self { + // If this I/O error is just a wrapped Isahc error, then unwrap it. + if let Some(inner) = error.get_ref() { + if inner.is::() { + return *error.into_inner().unwrap().downcast().unwrap(); + } } + + Self::new( + match error.kind() { + io::ErrorKind::ConnectionRefused => ErrorKind::ConnectionFailed, + io::ErrorKind::TimedOut => ErrorKind::Timeout, + _ => ErrorKind::Io, + }, + error, + ) } } -#[doc(hidden)] impl From for io::Error { - fn from(error: Error) -> io::Error { - match error { - Error::ConnectFailed => io::ErrorKind::ConnectionRefused.into(), - Error::Io(e) => e, - Error::Timeout => io::ErrorKind::TimedOut.into(), - e => io::Error::new(io::ErrorKind::Other, e), - } + fn from(error: Error) -> Self { + let kind = match error.kind() { + ErrorKind::ConnectionFailed => io::ErrorKind::ConnectionRefused, + ErrorKind::Timeout => io::ErrorKind::TimedOut, + _ => io::ErrorKind::Other, + }; + + Self::new(kind, error) + } +} + +impl From for Error { + fn from(error: http::Error) -> Error { + Self::new( + if error.is::() + || error.is::() + || error.is::() + || error.is::() + || error.is::() + { + ErrorKind::InvalidRequest + } else { + ErrorKind::Unknown + }, + error, + ) } } #[doc(hidden)] -impl From for Error { - fn from(_: std::string::FromUtf8Error) -> Error { - Error::InvalidUtf8 +impl From for Error { + fn from(error: curl::Error) -> Error { + Self::with_context( + if error.is_ssl_certproblem() || error.is_ssl_cacert_badfile() { + ErrorKind::BadClientCertificate + } else if error.is_peer_failed_verification() + || error.is_ssl_cacert() + || error.is_ssl_cipher() + || error.is_ssl_issuer_error() + { + ErrorKind::BadServerCertificate + } else if error.is_interface_failed() { + ErrorKind::ClientInitialization + } else if error.is_couldnt_connect() || error.is_ssl_connect_error() { + ErrorKind::ConnectionFailed + } else if error.is_bad_content_encoding() || error.is_conv_failed() { + ErrorKind::InvalidContentEncoding + } else if error.is_login_denied() { + ErrorKind::InvalidCredentials + } else if error.is_url_malformed() { + ErrorKind::InvalidRequest + } else if error.is_couldnt_resolve_host() || error.is_couldnt_resolve_proxy() { + ErrorKind::NameResolution + } else if error.is_got_nothing() + || error.is_http2_error() + || error.is_http2_stream_error() + || error.is_unsupported_protocol() + || error.code() == curl_sys::CURLE_FTP_WEIRD_SERVER_REPLY + { + ErrorKind::ProtocolViolation + } else if error.is_send_error() + || error.is_recv_error() + || error.is_read_error() + || error.is_write_error() + || error.is_upload_failed() + || error.is_send_fail_rewind() + || error.is_aborted_by_callback() + || error.is_partial_file() + { + ErrorKind::Io + } else if error.is_ssl_engine_initfailed() + || error.is_ssl_engine_notfound() + || error.is_ssl_engine_setfailed() + { + ErrorKind::TlsEngine + } else if error.is_operation_timedout() { + ErrorKind::Timeout + } else if error.is_too_many_redirects() { + ErrorKind::TooManyRedirects + } else { + ErrorKind::Unknown + }, + error.extra_description().map(String::from), + error, + ) } } #[doc(hidden)] -impl From for Error { - fn from(_: std::str::Utf8Error) -> Error { - Error::InvalidUtf8 +impl From for Error { + fn from(error: curl::MultiError) -> Error { + Self::new( + if error.is_bad_socket() { + ErrorKind::Io + } else { + ErrorKind::Unknown + }, + error, + ) } } + +#[cfg(test)] +mod tests { + use super::*; + + static_assertions::assert_impl_all!(Error: Send, Sync); +} diff --git a/src/handler.rs b/src/handler.rs index b9784e24..c719622d 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -146,14 +146,14 @@ impl RequestHandler { // Create a future that resolves when the handler receives the response // headers. let future = async move { - let builder = receiver.recv_async().await.map_err(|_| Error::Aborted)??; + let builder = receiver.recv_async().await.map_err(|e| Error::new(crate::error::ErrorKind::Unknown, e))??; let reader = ResponseBodyReader { inner: response_body_reader, shared, }; - builder.body(reader).map_err(Error::InvalidHttpFormat) + builder.body(reader).map_err(|e| Error::new(crate::error::ErrorKind::ProtocolViolation, e)) }; (handler, future) diff --git a/src/interceptor/context.rs b/src/interceptor/context.rs index 3eb3174f..e6ac28cf 100644 --- a/src/interceptor/context.rs +++ b/src/interceptor/context.rs @@ -22,18 +22,7 @@ impl<'a> Context<'a> { interceptors: &self.interceptors[1..], }; - match interceptor.intercept(request, inner_context).await { - Ok(response) => Ok(response), - - // If the error is an Isahc error, return it directly. - Err(e) => match e.downcast::() { - Ok(e) => Err(*e), - - // TODO: Introduce a new error variant for errors caused by an - // interceptor. This is a temporary hack. - Err(e) => Err(Error::Curl(e.to_string())), - }, - } + interceptor.intercept(request, inner_context).await } else { self.invoker.invoke(request).await } diff --git a/src/interceptor/mod.rs b/src/interceptor/mod.rs index e056891e..f266ec42 100644 --- a/src/interceptor/mod.rs +++ b/src/interceptor/mod.rs @@ -55,7 +55,7 @@ macro_rules! interceptor { async fn interceptor( mut $request: $crate::http::Request<$crate::Body>, $ctx: $crate::interceptor::Context<'_>, - ) -> Result<$crate::http::Response, Box> { + ) -> Result<$crate::http::Response, $crate::Error> { (move || async move { $body })().await.map_err(Into::into) @@ -72,7 +72,7 @@ macro_rules! interceptor { /// made in parallel. pub trait Interceptor: Send + Sync { /// The type of error returned by this interceptor. - type Err: Into>; + type Err: Error + Send + Sync + 'static; /// Intercept a request, returning a response. /// @@ -88,7 +88,7 @@ pub type InterceptorFuture<'a, E> = Pin(f: F) -> InterceptorFn where F: for<'a> private::AsyncFn2, Context<'a>, Output = Result, E>> + Send + Sync + 'static, - E: Into>, + E: Error + Send + Sync + 'static, { InterceptorFn(f) } @@ -99,7 +99,7 @@ pub struct InterceptorFn(F); impl Interceptor for InterceptorFn where - E: Into>, + E: Error + Send + Sync + 'static, F: for<'a> private::AsyncFn2, Context<'a>, Output = Result, E>> + Send + Sync + 'static, { type Err = E; diff --git a/src/interceptor/obj.rs b/src/interceptor/obj.rs index 594d13d1..98a03842 100644 --- a/src/interceptor/obj.rs +++ b/src/interceptor/obj.rs @@ -1,7 +1,6 @@ -use crate::Body; +use crate::{Body, Error}; use super::{Context, Interceptor, InterceptorFuture}; use http::Request; -use std::error::Error; /// Type-erased interceptor object. pub(crate) struct InterceptorObj(Box); @@ -13,7 +12,7 @@ impl InterceptorObj { } impl Interceptor for InterceptorObj { - type Err = Box; + type Err = Error; fn intercept<'a>(&'a self, request: Request, cx: Context<'a>) -> InterceptorFuture<'a, Self::Err> { self.0.dyn_intercept(request, cx) @@ -23,13 +22,13 @@ impl Interceptor for InterceptorObj { /// Object-safe version of the interceptor used for type erasure. Implementation /// detail of [`InterceptorObj`]. trait DynInterceptor: Send + Sync { - fn dyn_intercept<'a>(&'a self, request: Request, cx: Context<'a>) -> InterceptorFuture<'a, Box>; + fn dyn_intercept<'a>(&'a self, request: Request, cx: Context<'a>) -> InterceptorFuture<'a, Error>; } impl DynInterceptor for I { - fn dyn_intercept<'a>(&'a self, request: Request, cx: Context<'a>) -> InterceptorFuture<'a, Box> { + fn dyn_intercept<'a>(&'a self, request: Request, cx: Context<'a>) -> InterceptorFuture<'a, Error> { Box::pin(async move { - self.intercept(request, cx).await.map_err(Into::into) + self.intercept(request, cx).await.map_err(Error::from_any) }) } } diff --git a/src/lib.rs b/src/lib.rs index cb1ed448..01703d38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,7 +131,7 @@ //! //! ```toml //! [dependencies.isahc] -//! version = "0.8" +//! version = "0.9" //! features = ["psl"] //! ``` //! @@ -224,6 +224,9 @@ use http::{Request, Response}; use once_cell::sync::Lazy; use std::convert::TryFrom; +#[macro_use] +mod macros; + #[cfg(feature = "cookies")] pub mod cookies; @@ -231,7 +234,6 @@ mod agent; mod body; mod client; mod default_headers; -mod error; mod handler; mod headers; mod metrics; @@ -242,6 +244,7 @@ mod task; mod text; pub mod auth; +pub mod error; pub mod config; #[cfg(feature = "unstable-interceptors")] diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 00000000..1ecb2159 --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,23 @@ +/// Helper macro that allows you to attempt to downcast a generic type, as long +/// as it is known to be `'static`. +macro_rules! match_type { + { + $( + <$name:ident as $T:ty> => $branch:expr, + )* + $defaultName:ident => $defaultBranch:expr, + } => {{ + match () { + $( + _ if ::std::any::Any::type_id(&$name) == ::std::any::TypeId::of::<$T>() => { + #[allow(unsafe_code)] + let $name: $T = unsafe { + ::std::mem::transmute_copy::<_, $T>(&::std::mem::ManuallyDrop::new($name)) + }; + $branch + } + )* + _ => $defaultBranch, + } + }}; +} diff --git a/src/redirect.rs b/src/redirect.rs index f50f08b8..5e418f6f 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -1,9 +1,10 @@ use crate::{ config::RedirectPolicy, + error::{Error, ErrorKind}, handler::RequestBody, interceptor::{Context, Interceptor, InterceptorFuture}, request::RequestExt, - Body, Error, + Body, }; use http::{Request, Response, Uri}; use std::convert::TryFrom; @@ -74,7 +75,7 @@ impl Interceptor for RedirectInterceptor { if let Some(location) = get_redirect_location(&effective_uri, &response) { // If we've reached the limit, return an error as requested. if redirect_count >= limit { - return Err(Error::TooManyRedirects); + return Err(ErrorKind::TooManyRedirects.into()); } // Set referer header. @@ -108,14 +109,14 @@ impl Interceptor for RedirectInterceptor { // There's not really a good way of handling this gracefully, so // we just return an error so that the user knows about it. if !request_body.reset() { - return Err(Error::RequestBodyError(Some(String::from( - "could not follow redirect because request body is not rewindable", - )))); + return Err(ErrorKind::RequestBodyNotRewindable.into()); } // Update the request to point to the new URI. effective_uri = location.clone(); - request = request_builder.uri(location).body(request_body)?; + request = request_builder.uri(location) + .body(request_body) + .map_err(|e| Error::new(ErrorKind::InvalidRequest, e))?; redirect_count += 1; } diff --git a/src/task.rs b/src/task.rs index 3756a9ec..320b6e31 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,7 +1,7 @@ //! Helpers for working with tasks and futures. -use crate::Error; use std::{ + io, net::{SocketAddr, UdpSocket}, task::Waker, }; @@ -31,7 +31,7 @@ pub(crate) struct UdpWaker { impl UdpWaker { /// Create a waker by connecting to the wake address of an UDP server. - pub(crate) fn connect(addr: SocketAddr) -> Result { + pub(crate) fn connect(addr: SocketAddr) -> io::Result { let socket = UdpSocket::bind("127.0.0.1:0")?; socket.connect(addr)?; diff --git a/tests/encoding.rs b/tests/encoding.rs index 274c1e74..fa7ab7b4 100644 --- a/tests/encoding.rs +++ b/tests/encoding.rs @@ -126,7 +126,7 @@ fn unknown_content_encoding_returns_error() { .send(); match result { - Err(isahc::Error::InvalidContentEncoding(_)) => {} + Err(e) if e.kind() == isahc::error::ErrorKind::InvalidContentEncoding => {} _ => panic!("expected unknown encoding error, instead got {:?}", result), }; diff --git a/tests/redirects.rs b/tests/redirects.rs index 9f05acff..81b81d6f 100644 --- a/tests/redirects.rs +++ b/tests/redirects.rs @@ -225,7 +225,7 @@ fn redirect_non_rewindable_body_returns_error() { .unwrap() .send(); - assert_matches!(result, Err(isahc::Error::RequestBodyError(_))); + assert_matches!(result, Err(e) if e == isahc::error::ErrorKind::RequestBodyNotRewindable); assert_eq!(m1.request().method, "POST"); } @@ -245,10 +245,7 @@ fn redirect_limit_is_respected() { .send(); // Request should error with too many redirects. - assert!(match result { - Err(isahc::Error::TooManyRedirects) => true, - _ => false, - }); + assert_matches!(result, Err(e) if e == isahc::error::ErrorKind::TooManyRedirects); // After request (limit + 1) that returns a redirect should error. assert_eq!(m.requests().len(), 6); diff --git a/tests/timeouts.rs b/tests/timeouts.rs index c84948f6..d08db6c0 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -21,7 +21,7 @@ fn request_errors_if_read_timeout_is_reached() { .send(); // Client should time-out. - assert_matches!(result, Err(isahc::Error::Timeout)); + assert_matches!(result, Err(e) if e == isahc::error::ErrorKind::Timeout); assert_eq!(m.requests().len(), 1); } diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 767b9ab9..33440ff1 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,13 +1,13 @@ macro_rules! assert_matches { - ($value:expr, $pattern:pat) => {{ + ($value:expr, $($pattern:tt)+) => {{ match $value { - $pattern => {}, + $($pattern)* => {}, value => panic!( "assertion failed: `{}` matches `{}`\n value: `{:?}`", stringify!($value), - stringify!($pattern), + stringify!($($pattern)*), value, ), } - }} + }}; }