Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom HTTP Response abstraction #923

Merged
merged 7 commits into from
Jul 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions gateway/queue/src/day_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ impl DayLimiter {
.gateway()
.authed()
.await
.map_err(|source| DayLimiterError {
kind: DayLimiterErrorType::RetrievingSessionAvailability,
source: Some(Box::new(source)),
})?
.model()
.await
.map_err(|source| DayLimiterError {
kind: DayLimiterErrorType::RetrievingSessionAvailability,
source: Some(Box::new(source)),
Expand Down Expand Up @@ -88,23 +94,27 @@ impl DayLimiter {
} else {
let wait = lock.last_check + lock.next_reset;
time::sleep_until(wait).await;
if let Ok(info) = lock.http.gateway().authed().await {
let last_check = Instant::now();
let next_reset = Duration::from_millis(info.session_start_limit.remaining);
tracing::info!("next session start limit reset in: {:.2?}", next_reset);
let total = info.session_start_limit.total;
let remaining = info.session_start_limit.remaining;
assert!(total >= remaining);
let current = total - remaining;
lock.last_check = last_check;
lock.next_reset = next_reset;
lock.total = total;
lock.current = current + 1;
} else {
tracing::warn!(
"unable to get new session limits, skipping (this may cause bad things)"
)
if let Ok(res) = lock.http.gateway().authed().await {
if let Ok(info) = res.model().await {
let last_check = Instant::now();
let next_reset = Duration::from_millis(info.session_start_limit.remaining);
tracing::info!("next session start limit reset in: {:.2?}", next_reset);
zeylahellyer marked this conversation as resolved.
Show resolved Hide resolved
let total = info.session_start_limit.total;
let remaining = info.session_start_limit.remaining;
assert!(total >= remaining);
let current = total - remaining;
lock.last_check = last_check;
lock.next_reset = next_reset;
lock.total = total;
lock.current = current + 1;

return;
}
}

tracing::warn!(
"unable to get new session limits, skipping (this may cause bad things)"
);
zeylahellyer marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
17 changes: 7 additions & 10 deletions gateway/src/cluster/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,13 @@ impl ClusterBuilder {
ClusterStartError,
> {
if (self.1).0.gateway_url.is_none() {
let gateway_url = (self.1)
.0
.http_client
.gateway()
.authed()
.await
.ok()
.map(|s| s.url);

self = self.gateway_url(gateway_url);
let maybe_response = (self.1).0.http_client.gateway().authed().await;

if let Ok(response) = maybe_response {
let gateway_url = response.model().await.ok().map(|info| info.url);

self = self.gateway_url(gateway_url);
}
}

self.0.shard_config = (self.1).0;
Expand Down
6 changes: 6 additions & 0 deletions gateway/src/cluster/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,12 @@ impl Cluster {
.gateway()
.authed()
.await
.map_err(|source| ClusterStartError {
kind: ClusterStartErrorType::RetrievingGatewayInfo,
source: Some(Box::new(source)),
})?
.model()
.await
.map_err(|source| ClusterStartError {
kind: ClusterStartErrorType::RetrievingGatewayInfo,
source: Some(Box::new(source)),
Expand Down
6 changes: 6 additions & 0 deletions gateway/src/shard/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,12 @@ impl Shard {
source: Some(Box::new(source)),
kind: ShardStartErrorType::RetrievingGatewayUrl,
})?
.model()
.await
.map_err(|source| ShardStartError {
source: Some(Box::new(source)),
kind: ShardStartErrorType::RetrievingGatewayUrl,
})?
.url
};

Expand Down
2 changes: 1 addition & 1 deletion http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ percent-encoding = { default-features = false, version = "2" }
tokio = { default-features = false, features = ["time"], version = "1.0" }
twilight-model = { default-features = false, path = "../model" }
serde = { default-features = false, features = ["derive"], version = "1" }
serde_json = { default-features = false, features = ["alloc"], version = "1" }
serde_json = { default-features = false, features = ["std"], version = "1" }

# optional
simd-json = { default-features = false, features = ["serde_impl", "swar-number-parsing"], optional = true, version = "0.4" }
Expand Down
2 changes: 1 addition & 1 deletion http/examples/get-message/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
}))
.await;

let me = client.current_user().await?;
let me = client.current_user().await?.model().await?;
println!("Current user: {}#{}", me.name, me.discriminator);

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion http/examples/proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
}))
.await;

let me = client.current_user().await?;
let me = client.current_user().await?.model().await?;
println!("Current user: {}#{}", me.name, me.discriminator);

Ok(())
Expand Down
93 changes: 23 additions & 70 deletions http/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@ use crate::{
prelude::*,
GetUserApplicationInfo, Method, Request,
},
response::{Response, StatusCode},
API_VERSION,
};
use hyper::body::Bytes;
use hyper::{
body,
client::{Client as HyperClient, HttpConnector},
header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, USER_AGENT},
Body, Response, StatusCode,
Body, StatusCode as HyperStatusCode,
};
use serde::de::DeserializeOwned;
use std::{
convert::TryFrom,
fmt::{Debug, Formatter, Result as FmtResult},
Expand Down Expand Up @@ -652,7 +651,7 @@ impl Client {
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let client = Client::new("my token");
/// #
/// let info = client.gateway().authed().await?;
/// let info = client.gateway().authed().await?.model().await?;
///
/// println!("URL: {}", info.url);
/// println!("Recommended shards to use: {}", info.shards);
Expand Down Expand Up @@ -907,6 +906,8 @@ impl Client {
/// let member = client.update_guild_member(GuildId(1), UserId(2))
/// .mute(true)
/// .nick(Some("pinkie pie".to_owned()))?
/// .await?
/// .model()
/// .await?;
///
/// println!("user {} now has the nickname '{:?}'", member.user.id, member.nick);
Expand Down Expand Up @@ -2120,7 +2121,7 @@ impl Client {
/// Returns an [`ErrorType::Unauthorized`] error type if the configured
/// token has become invalid due to expiration, revokation, etc.
#[allow(clippy::too_many_lines)]
pub async fn raw(&self, request: Request) -> Result<Response<Body>, Error> {
pub async fn request<T>(&self, request: Request) -> Result<Response<T>, Error> {
if self.state.token_invalid.load(Ordering::Relaxed) {
return Err(Error {
kind: ErrorType::Unauthorized,
Expand Down Expand Up @@ -2243,16 +2244,17 @@ impl Client {
let ratelimiter = match self.state.ratelimiter.as_ref() {
Some(ratelimiter) => ratelimiter,
None => {
return fut
.await
.map_err(|source| Error {
kind: ErrorType::RequestTimedOut,
source: Some(Box::new(source)),
})?
.map_err(|source| Error {
kind: ErrorType::RequestError,
source: Some(Box::new(source)),
});
return Ok(Response::new(
fut.await
.map_err(|source| Error {
kind: ErrorType::RequestTimedOut,
source: Some(Box::new(source)),
})?
.map_err(|source| Error {
kind: ErrorType::RequestError,
source: Some(Box::new(source)),
})?,
));
}
};

Expand All @@ -2276,7 +2278,7 @@ impl Client {
// If the API sent back an Unauthorized response, then the client's
// configured token is permanently invalid and future requests must be
// ignored to avoid API bans.
if resp.status() == StatusCode::UNAUTHORIZED {
if resp.status() == HyperStatusCode::UNAUTHORIZED {
self.state.token_invalid.store(true, Ordering::Relaxed);
}

Expand All @@ -2293,74 +2295,25 @@ impl Client {
}
}

Ok(resp)
}

/// Execute a request, chunking and deserializing the response.
///
/// # Errors
///
/// Returns an [`ErrorType::Unauthorized`] error type if the configured
/// token has become invalid due to expiration, revokation, etc.
pub async fn request<T: DeserializeOwned>(&self, request: Request) -> Result<T, Error> {
let resp = self.make_request(request).await?;

let bytes = body::to_bytes(resp.into_body())
.await
.map_err(|source| Error {
kind: ErrorType::ChunkingResponse,
source: Some(Box::new(source)),
})?;

crate::json::parse_bytes(&bytes)
}

pub(crate) async fn request_bytes(&self, request: Request) -> Result<Bytes, Error> {
let resp = self.make_request(request).await?;

hyper::body::to_bytes(resp.into_body())
.await
.map_err(|source| Error {
kind: ErrorType::ChunkingResponse,
source: Some(Box::new(source)),
})
}

/// Execute a request, checking only that the response was a success.
///
/// This will not chunk and deserialize the body of the response.
///
/// # Errors
///
/// Returns an [`ErrorType::Unauthorized`] error type if the configured
/// token has become invalid due to expiration, revokation, etc.
pub async fn verify(&self, request: Request) -> Result<(), Error> {
self.make_request(request).await?;

Ok(())
}

async fn make_request(&self, request: Request) -> Result<Response<Body>, Error> {
let resp = self.raw(request).await?;
let status = resp.status();

if status.is_success() {
return Ok(resp);
return Ok(Response::new(resp));
}

match status {
StatusCode::IM_A_TEAPOT => {
HyperStatusCode::IM_A_TEAPOT => {
#[cfg(feature = "tracing")]
tracing::warn!(
"discord's api now runs off of teapots -- proceed to panic: {:?}",
resp,
);
}
StatusCode::TOO_MANY_REQUESTS => {
HyperStatusCode::TOO_MANY_REQUESTS => {
#[cfg(feature = "tracing")]
tracing::warn!("429 response: {:?}", resp);
}
StatusCode::SERVICE_UNAVAILABLE => {
HyperStatusCode::SERVICE_UNAVAILABLE => {
return Err(Error {
kind: ErrorType::ServiceUnavailable { response: resp },
source: None,
Expand Down Expand Up @@ -2391,7 +2344,7 @@ impl Client {
kind: ErrorType::Response {
body: bytes.to_vec(),
error,
status,
status: StatusCode::new(status.as_u16()),
},
source: None,
})
Expand Down
9 changes: 2 additions & 7 deletions http/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
use crate::api_error::ApiError;
use hyper::{Body, Response, StatusCode};
use crate::{api_error::ApiError, json::JsonError, response::StatusCode};
use hyper::{Body, Response};
use std::{
error::Error as StdError,
fmt::{Debug, Display, Formatter, Result as FmtResult},
};

#[cfg(not(feature = "simd-json"))]
use serde_json::Error as JsonError;
#[cfg(feature = "simd-json")]
use simd_json::Error as JsonError;

#[derive(Debug)]
pub struct Error {
pub(super) source: Option<Box<dyn StdError + Send + Sync>>,
Expand Down
4 changes: 2 additions & 2 deletions http/src/json.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(not(feature = "simd-json"))]
pub use serde_json::to_vec;
pub use serde_json::{to_vec, Deserializer as JsonDeserializer, Error as JsonError};
#[cfg(feature = "simd-json")]
pub use simd_json::to_vec;
pub use simd_json::{to_vec, Deserializer as JsonDeserializer, Error as JsonError};

use crate::error::{Error, ErrorType};
use hyper::body::Bytes;
Expand Down
3 changes: 2 additions & 1 deletion http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,15 @@ pub mod client;
pub mod error;
pub mod ratelimiting;
pub mod request;
pub mod response;
pub mod routing;

mod json;

/// Discord API version used by this crate.
pub const API_VERSION: u8 = 8;

pub use crate::{client::Client, error::Error};
pub use crate::{client::Client, error::Error, response::Response};

#[cfg(not(any(
feature = "native",
Expand Down
Loading