Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port from reqwest to ureq #205

Merged
merged 5 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ geo-types = { version = "0.7.10", optional = true }
libc = "0.2.119"
num-traits = "0.2.14"
thiserror = "1.0.30"
reqwest = { version = "0.12.0", optional = true, default-features = false, features = ["blocking", "rustls-tls"] }
ureq = { version = "2.0.0", optional = true }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you actually test this on 2.0.0? Otherwise please specify whatever version you tested on (latest is 2.9.7).

Otherwise it's hard to know if you are leveraging any functionality that was added in a minor release.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I tested it on 2.0.0.


[workspace]
members = ["proj-sys"]
Expand All @@ -28,7 +28,7 @@ members = ["proj-sys"]
default = ["geo-types"]
bundled_proj = [ "proj-sys/bundled_proj" ]
pkg_config = [ "proj-sys/pkg_config" ]
network = ["reqwest", "proj-sys/network"]
network = ["ureq", "proj-sys/network"]

[dev-dependencies]
# approx version must match the one used in geo-types
Expand Down
130 changes: 73 additions & 57 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
// This functionality based on https://github.com/OSGeo/PROJ/blob/master/src/networkfilemanager.cpp#L1675
use proj_sys::{proj_context_set_network_callbacks, PJ_CONTEXT, PROJ_NETWORK_HANDLE};

use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::Method;
use std::collections::HashMap;
use std::ffi::CString;
use std::io::Read;
use std::ops::Range;
use std::os::raw::c_ulonglong;
use std::ptr::{self, NonNull};
use ureq::{Agent, Request, Response};

use crate::proj::{ProjError, _string};
use libc::c_char;
Expand All @@ -30,11 +32,14 @@ const CLIENT: &str = concat!("proj-rs/", env!("CARGO_PKG_VERSION"));
const MAX_RETRIES: u8 = 8;
// S3 sometimes sends these in place of actual client errors, so retry instead of erroring
const RETRY_CODES: [u16; 4] = [429, 500, 502, 504];
const SUCCESS_ERROR_CODES: Range<u16> = 200..300;
const CLIENT_ERROR_CODES: Range<u16> = 400..500;
const SERVER_ERROR_CODES: Range<u16> = 500..600;

/// This struct is cast to `c_void`, then to `PROJ_NETWORK_HANDLE` so it can be passed around
struct HandleData {
url: String,
headers: reqwest::header::HeaderMap,
headers: HashMap<String, String>,
// this raw pointer is handed out to libproj but never returned,
// so a copy of the pointer (raw pointers are Copy) is stored here.
// Note to future self: are you 100% sure that the pointer is never read again
Expand All @@ -43,11 +48,7 @@ struct HandleData {
}

impl HandleData {
fn new(
url: String,
headers: reqwest::header::HeaderMap,
hptr: Option<NonNull<c_char>>,
) -> Self {
fn new(url: String, headers: HashMap<String, String>, hptr: Option<NonNull<c_char>>) -> Self {
Self { url, headers, hptr }
}
}
Expand All @@ -74,36 +75,34 @@ fn get_wait_time_exp(retrycount: i32) -> u64 {

/// Process CDN response: handle retries in case of server error, or early return for client errors
/// Successful retry data is stored into res
fn error_handler(res: &mut Response, rb: RequestBuilder) -> Result<&Response, ProjError> {
let mut status = res.status().as_u16();
fn error_handler(res: &mut Response, rb: Request) -> Result<&Response, ProjError> {
let mut retries = 0;
// Check whether something went wrong on the server, or if it's an S3 retry code
if res.status().is_server_error() || RETRY_CODES.contains(&status) {
if SERVER_ERROR_CODES.contains(&res.status()) || RETRY_CODES.contains(&res.status()) {
// Start retrying: up to MAX_RETRIES
while (res.status().is_server_error() || RETRY_CODES.contains(&status))
while (SERVER_ERROR_CODES.contains(&res.status()) || RETRY_CODES.contains(&res.status()))
&& retries <= MAX_RETRIES
{
retries += 1;
let wait = time::Duration::from_millis(get_wait_time_exp(retries as i32));
thread::sleep(wait);
let retry = rb.try_clone().ok_or(ProjError::RequestCloneError)?;
*res = retry.send()?;
status = res.status().as_u16();
let retry = rb.clone();
*res = retry.call()?;
}
// Not a timeout or known S3 retry code: bail out
} else if res.status().is_client_error() {
} else if CLIENT_ERROR_CODES.contains(&res.status()) {
return Err(ProjError::DownloadError(
res.status().as_str().to_string(),
res.url().to_string(),
res.status_text().to_string(),
res.get_url().to_string(),
retries,
));
}
// Retries have been exhausted OR
// The loop ended prematurely due to a different error
if !res.status().is_success() {
if !SUCCESS_ERROR_CODES.contains(&res.status()) {
return Err(ProjError::DownloadError(
res.status().as_str().to_string(),
res.url().to_string(),
res.status_text().to_string(),
res.get_url().to_string(),
retries,
));
}
Expand Down Expand Up @@ -173,26 +172,35 @@ unsafe fn _network_open(
// RANGE header definition is "bytes=x-y"
let hvalue = format!("bytes={offset}-{end}");
// Create a new client that can be reused for subsequent queries
let clt = Client::builder().build()?;
let req = clt.request(Method::GET, &url);
// this performs the initial byte read, presumably as an error check
let initial = req.try_clone().ok_or(ProjError::RequestCloneError)?;
let with_headers = initial.header("Range", &hvalue).header("Client", CLIENT);
let mut res = with_headers.send()?;
let in_case_of_error = req
.try_clone()
.ok_or(ProjError::RequestCloneError)?
.header("Range", &hvalue);
let clt = Agent::new();
let req = clt.get(&url);
let with_headers = req.set("Range", &hvalue).set("Client", CLIENT);
let in_case_of_error = with_headers.clone();
let mut res = with_headers.call()?;
// hand the response off to the error-handler, continue on success
error_handler(&mut res, in_case_of_error)?;
// Write the initial read length value into the pointer
let contentlength = res.content_length().ok_or(ProjError::ContentLength)? as usize;
out_size_read.write(contentlength);
let headers = res.headers().clone();
let Some(Ok(contentlength)) = res.header("Content-Length").map(str::parse::<usize>) else {
return Err(ProjError::ContentLength);
};
let headers = res
.headers_names()
.into_iter()
.filter_map(|h| {
Some({
let v = res.header(&h)?.to_string();
(h, v)
})
})
.collect();
// Copy the downloaded bytes into the buffer so it can be passed around
res.bytes()?
.as_ptr()
.copy_to_nonoverlapping(buffer.cast(), contentlength.min(size_to_read));
let capacity = contentlength.min(size_to_read);
let mut buf = Vec::with_capacity(capacity);
res.into_reader()
.take(size_to_read as u64)
.read_to_end(&mut buf)?;
out_size_read.write(buf.len());
buf.as_ptr().copy_to_nonoverlapping(buffer.cast(), capacity);
let hd = HandleData::new(url, headers, None);
// heap-allocate the struct and cast it to a void pointer so it can be passed around to PROJ
let hd_boxed = Box::new(hd);
Expand Down Expand Up @@ -255,9 +263,8 @@ unsafe fn _network_get_header_value(
let hvalue = hd
.headers
.get(&lookup)
.ok_or_else(|| ProjError::HeaderError(lookup.to_string()))?
.to_str()?;
let cstr = CString::new(hvalue).unwrap();
.ok_or_else(|| ProjError::HeaderError(lookup.to_string()))?;
let cstr = CString::new(&**hvalue).unwrap();
let header = cstr.into_raw();
// Raw pointers are Copy: the pointer returned by this function is never returned by libproj so
// in order to avoid a memory leak the pointer is copied and stored in the HandleData struct,
Expand Down Expand Up @@ -327,34 +334,43 @@ fn _network_read_range(
let end = offset as usize + size_to_read - 1;
let hvalue = format!("bytes={offset}-{end}");
let hd = unsafe { &mut *(handle as *const c_void as *mut HandleData) };
let clt = Client::builder().build()?;
let initial = clt.request(Method::GET, &hd.url);
let in_case_of_error = initial
.try_clone()
.ok_or(ProjError::RequestCloneError)?
.header("Range", &hvalue)
.header("Client", CLIENT);
let req = in_case_of_error
.try_clone()
.ok_or(ProjError::RequestCloneError)?;
let mut res = req.send()?;
let clt = Agent::new();
let initial = clt.get(&hd.url);
let in_case_of_error = initial.clone().set("Range", &hvalue).set("Client", CLIENT);
let req = in_case_of_error.clone();
let mut res = req.call()?;
// hand the response and retry instance off to the error-handler, continue on success
error_handler(&mut res, in_case_of_error)?;
let headers = res.headers().clone();
let contentlength = res.content_length().ok_or(ProjError::ContentLength)? as usize;
let headers = res
.headers_names()
.into_iter()
.filter_map(|h| {
Some({
let v = res.header(&h)?.to_string();
(h, v)
})
})
.collect();
let Some(Ok(contentlength)) = res.header("Content-Length").map(str::parse::<usize>) else {
return Err(ProjError::ContentLength);
};
// Copy the downloaded bytes into the buffer so it can be passed around
let capacity = contentlength.min(size_to_read);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pre-existing issue, but we aren't consistent about out_size_read. If Content-Length is larger than size_to_read, we set out_size_read to a value larger than the actual length of the buffer. We should just set it to buf.len(), regardless or how we handle that case.

Same for _network_open of course.

let mut buf = Vec::with_capacity(capacity);
res.into_reader()
.take(size_to_read as u64)
.read_to_end(&mut buf)?;
unsafe {
res.bytes()?
.as_ptr()
.copy_to_nonoverlapping(buffer.cast::<u8>(), contentlength.min(size_to_read));
buf.as_ptr()
.copy_to_nonoverlapping(buffer.cast::<u8>(), capacity);
}
let err_string = "";
unsafe {
out_error_string.copy_from_nonoverlapping(err_string.as_ptr().cast(), err_string.len());
out_error_string.add(err_string.len()).write(0);
}
hd.headers = headers;
Ok(contentlength)
Ok(buf.len())
}

/// Set up and initialise the grid download callback functions for all subsequent PROJ contexts
Expand Down
18 changes: 13 additions & 5 deletions src/proj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl<T: CoordinateType> Coord<T> for (T, T) {

/// Errors originating in PROJ which can occur during projection and conversion
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ProjError {
/// A projection error
#[error("The projection failed with the following error: {0}")]
Expand All @@ -107,24 +108,31 @@ pub enum ProjError {
Network,
#[error("Could not set remote grid download callbacks")]
RemoteCallbacks,
#[error("Couldn't build request")]
#[error("Couldn't access the network")]
#[cfg(feature = "network")]
BuilderError(#[from] reqwest::Error),
NetworkError(Box<ureq::Error>),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's mark ProjError as #[non_exhaustive], so we don't take a breaking change next time we add a variant.

#[error("Couldn't clone request")]
RequestCloneError,
#[error("Could not retrieve content length")]
ContentLength,
#[error("Couldn't retrieve header for key {0}")]
HeaderError(String),
#[cfg(feature = "network")]
#[error("Couldn't convert header value to str")]
HeaderConversion(#[from] reqwest::header::ToStrError),
#[error("Couldn't read response to buffer")]
ReadError(#[from] std::io::Error),
#[error("A {0} error occurred for url {1} after {2} retries")]
DownloadError(String, String, u8),
#[error("The current definition could not be retrieved")]
Definition,
}

#[cfg(feature = "network")]
impl From<ureq::Error> for ProjError {
fn from(e: ureq::Error) -> Self {
Self::NetworkError(Box::new(e))
}
}

#[derive(Error, Debug)]
pub enum ProjCreateError {
#[error("A nul byte was found in the PROJ string definition or CRS argument: {0}")]
Expand Down Expand Up @@ -1461,7 +1469,7 @@ mod test {
let usa_m = MyPoint::new(-115.797615, 37.2647978);
let usa_ft = to_feet.convert(usa_m).unwrap();
assert_relative_eq!(6693625.67217475, usa_ft.x());
assert_relative_eq!(3497301.5918027232, usa_ft.y(), epsilon=1e-8);
assert_relative_eq!(3497301.5918027232, usa_ft.y(), epsilon = 1e-8);
}

#[test]
Expand Down
Loading