Skip to content

Commit

Permalink
extract middleware_fn json_response from HttpSessionMiddleware.
Browse files Browse the repository at this point in the history
  • Loading branch information
youngsofun committed Aug 12, 2024
1 parent 53793b8 commit a64311f
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 37 deletions.
35 changes: 27 additions & 8 deletions src/query/service/src/servers/http/http_services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use poem::IntoEndpoint;
use poem::Route;

use super::v1::upload_to_stage;
use crate::servers::http::middleware::json_response;
use crate::servers::http::middleware::EndpointKind;
use crate::servers::http::middleware::HTTPSessionMiddleware;
use crate::servers::http::middleware::PanicHandler;
Expand Down Expand Up @@ -86,7 +87,7 @@ impl HttpHandler {
})
}

fn wrap_auth<E>(&self, ep: E, auth_type: EndpointKind) -> impl Endpoint
pub fn wrap_auth<E>(&self, ep: E, auth_type: EndpointKind) -> impl Endpoint
where
E: IntoEndpoint,
E::Endpoint: 'static,
Expand All @@ -99,26 +100,43 @@ impl HttpHandler {
#[async_backtrace::framed]
async fn build_router(&self, sock: SocketAddr) -> impl Endpoint {
let ep_v1 = Route::new()
.nest("/query", self.wrap_auth(query_route(), EndpointKind::Query))
.nest("/query", query_route(self.kind))
.at(
"/session/login",
self.wrap_auth(post(login_handler), EndpointKind::Login),
post(login_handler).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::Login,
)),
)
.at(
"/session/renew",
self.wrap_auth(post(renew_handler), EndpointKind::Refresh),
post(renew_handler).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::Refresh,
)),
)
.at(
"/upload_to_stage",
self.wrap_auth(put(upload_to_stage), EndpointKind::Query),
put(upload_to_stage).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::StartQuery,
)),
)
.at(
"/suggested_background_tasks",
self.wrap_auth(get(list_suggestions), EndpointKind::Query),
get(list_suggestions).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::StartQuery,
)),
);

let ep_clickhouse = Route::new().nest("/", clickhouse_router());
let ep_clickhouse = self.wrap_auth(ep_clickhouse, EndpointKind::Clickhouse);
let ep_clickhouse =
Route::new()
.nest("/", clickhouse_router())
.with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::Clickhouse,
));

let ep_usage = Route::new().at(
"/",
Expand All @@ -140,6 +158,7 @@ impl HttpHandler {
};
ep.with(NormalizePath::new(TrailingSlash::Trim))
.with(CatchPanic::new().with_handler(PanicHandler::new()))
.around(json_response)
.boxed()
}

Expand Down
46 changes: 27 additions & 19 deletions src/query/service/src/servers/http/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ use crate::sessions::SessionType;
pub enum EndpointKind {
Login,
Refresh,
Query,
StartQuery,
PollQuery,
Clickhouse,
}

Expand Down Expand Up @@ -207,10 +208,10 @@ fn auth_by_header(
let token = bearer.token().to_string();
if SessionClaim::is_databend_token(&token) {
let (token_type, set_user) = match endpoint_kind {
EndpointKind::Login => (TokenType::Session, false),
EndpointKind::Refresh => (TokenType::Refresh, true),
EndpointKind::Query => (TokenType::Session, true),
EndpointKind::Clickhouse => {
EndpointKind::StartQuery => (TokenType::Session, true),
EndpointKind::PollQuery => (TokenType::Session, false),
_ => {
return Err(ErrorCode::AuthenticateFailure(format!(
"should not use databend auth when accessing {path}"
)));
Expand Down Expand Up @@ -367,10 +368,10 @@ impl<E: Endpoint> Endpoint for HTTPSessionEndpoint<E> {
let _guard = ThreadTracker::tracking(tracking_payload);

ThreadTracker::tracking_future(async move {
let res = match self.auth(&req, query_id).await {
match self.auth(&req, query_id).await {
Ok(ctx) => {
req.extensions_mut().insert(ctx);
self.ep.call(req).await
self.ep.call(req).await.map(|v| v.into_response())
}
Err(err) => match err.code() {
ErrorCode::AUTHENTICATE_FAILURE
Expand Down Expand Up @@ -401,19 +402,6 @@ impl<E: Endpoint> Endpoint for HTTPSessionEndpoint<E> {
))
}
},
};
match res {
Err(err) => {
let body = Body::from_json(serde_json::json!({
"error": {
"code": err.status().as_str(),
"message": err.to_string(),
}
}))
.unwrap();
Ok(Response::builder().status(err.status()).body(body))
}
Ok(res) => Ok(res.into_response()),
}
})
.await
Expand Down Expand Up @@ -498,3 +486,23 @@ impl poem::middleware::PanicHandler for PanicHandler {
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error")
}
}
pub async fn json_response<E: Endpoint>(next: E, req: Request) -> PoemResult<Response> {
let res = next.call(req).await;

match res {
Ok(resp) => {
let resp = resp.into_response();
Ok(resp)
}
Err(err) => {
let body = Body::from_json(serde_json::json!({
"error": {
"code": err.status().as_str(),
"message": err.to_string(),
}
}))
.unwrap();
Ok(Response::builder().status(err.status()).body(body))
}
}
}
17 changes: 15 additions & 2 deletions src/query/service/src/servers/http/v1/http_query_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ use super::query::ExecuteStateKind;
use super::query::HttpQueryRequest;
use super::query::HttpQueryResponseInternal;
use super::query::RemoveReason;
use crate::servers::http::middleware::EndpointKind;
use crate::servers::http::middleware::HTTPSessionMiddleware;
use crate::servers::http::middleware::MetricsMiddleware;
use crate::servers::http::v1::query::Progresses;
use crate::servers::http::v1::HttpQueryContext;
use crate::servers::http::v1::HttpQueryManager;
use crate::servers::http::v1::HttpSessionConf;
use crate::servers::http::v1::StringBlock;
use crate::servers::HttpHandlerKind;
use crate::sessions::QueryAffect;

pub fn make_page_uri(query_id: &str, page_no: usize) -> String {
Expand Down Expand Up @@ -396,7 +399,7 @@ pub(crate) async fn query_handler(
.await
}

pub fn query_route() -> Route {
pub fn query_route(http_handler_kind: HttpHandlerKind) -> Route {
// Note: endpoints except /v1/query may change without notice, use uris in response instead
let rules = [
("/", post(query_handler)),
Expand All @@ -414,7 +417,17 @@ pub fn query_route() -> Route {

let mut route = Route::new();
for (path, endpoint) in rules.into_iter() {
route = route.at(path, endpoint.with(MetricsMiddleware::new(path)));
let kind = if path == "/" {
EndpointKind::StartQuery
} else {
EndpointKind::PollQuery
};
route = route.at(
path,
endpoint
.with(MetricsMiddleware::new(path))
.with(HTTPSessionMiddleware::create(http_handler_kind, kind)),
);
}
route
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl SuggestedBackgroundTasksSource {
info!(
background = true,
tenant = ctx.get_tenant().tenant_name().to_string();
"list all lsuggestions"
"list all suggestions"
);
Self::get_suggested_compaction_tasks(ctx).await
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ use wiremock::ResponseTemplate;

use crate::tests::tls_constants::*;

type EndpointType = HTTPSessionEndpoint<Route>;
type EndpointType = Route;

fn unwrap_data<'a>(data: &'a [Vec<Option<String>>], null_as: &'a str) -> Vec<Vec<&'a str>> {
data.iter()
Expand Down Expand Up @@ -857,12 +857,7 @@ async fn post_sql(sql: &str, wait_time_secs: u64) -> Result<(StatusCode, QueryRe
}

pub fn create_endpoint() -> Result<EndpointType> {
let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, EndpointKind::Query);

Ok(Route::new()
.nest("/v1/query", query_route())
.with(session_middleware))
Ok(Route::new().nest("/v1/query", query_route(HttpHandlerKind::Query)))
}

async fn post_json(json: &serde_json::Value) -> Result<(StatusCode, QueryResponse)> {
Expand Down

0 comments on commit a64311f

Please sign in to comment.