Skip to content

Commit

Permalink
avoid use url path directly.
Browse files Browse the repository at this point in the history
  • Loading branch information
youngsofun committed Aug 12, 2024
1 parent 5a0400f commit b8cfada
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 78 deletions.
4 changes: 2 additions & 2 deletions src/query/service/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct AuthMgr {
}

pub enum Credential {
Databend {
DatabendToken {
token: String,
token_type: TokenType,
set_user: bool,
Expand Down Expand Up @@ -75,7 +75,7 @@ impl AuthMgr {
pub async fn auth(&self, session: &mut Session, credential: &Credential) -> Result<()> {
let user_api = UserApiProvider::instance();
match credential {
Credential::Databend {
Credential::DatabendToken {
token,
set_user,
token_type,
Expand Down
37 changes: 26 additions & 11 deletions src/query/service/src/servers/http/http_services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ use poem::post;
use poem::put;
use poem::Endpoint;
use poem::EndpointExt;
use poem::IntoEndpoint;
use poem::Route;

use super::v1::upload_to_stage;
use crate::auth::AuthMgr;
use crate::servers::http::middleware::EndpointKind;
use crate::servers::http::middleware::HTTPSessionMiddleware;
use crate::servers::http::middleware::PanicHandler;
use crate::servers::http::v1::clickhouse_router;
Expand Down Expand Up @@ -85,25 +86,39 @@ impl HttpHandler {
})
}

fn wrap_auth(&self, ep: Route) -> impl Endpoint {
let auth_manager = AuthMgr::instance();
let session_middleware = HTTPSessionMiddleware::create(self.kind, auth_manager);
fn wrap_auth<E>(&self, ep: E, auth_type: EndpointKind) -> impl Endpoint
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
let session_middleware = HTTPSessionMiddleware::create(self.kind, auth_type);
ep.with(session_middleware).boxed()
}

#[allow(clippy::let_with_type_underscore)]
#[async_backtrace::framed]
async fn build_router(&self, sock: SocketAddr) -> impl Endpoint {
let ep_v1 = Route::new()
.nest("/query", query_route())
.at("/session/login", post(login_handler))
.at("/session/renew", post(renew_handler))
.at("/upload_to_stage", put(upload_to_stage))
.at("/suggested_background_tasks", get(list_suggestions));
let ep_v1 = self.wrap_auth(ep_v1);
.nest("/query", self.wrap_auth(query_route(), EndpointKind::Query))
.at(
"/session/login",
self.wrap_auth(post(login_handler), EndpointKind::Login),
)
.at(
"/session/renew",
self.wrap_auth(post(renew_handler), EndpointKind::Refresh),
)
.at(
"/upload_to_stage",
self.wrap_auth(put(upload_to_stage), EndpointKind::Query),
)
.at(
"/suggested_background_tasks",
self.wrap_auth(get(list_suggestions), EndpointKind::Query),
);

let ep_clickhouse = Route::new().nest("/", clickhouse_router());
let ep_clickhouse = self.wrap_auth(ep_clickhouse);
let ep_clickhouse = self.wrap_auth(ep_clickhouse, EndpointKind::Clickhouse);

let ep_usage = Route::new().at(
"/",
Expand Down
59 changes: 42 additions & 17 deletions src/query/service/src/servers/http/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,31 @@ use crate::servers::HttpHandlerKind;
use crate::sessions::SessionManager;
use crate::sessions::SessionType;

#[derive(Copy, Clone)]
pub enum EndpointKind {
Login,
Refresh,
Query,
Clickhouse,
}

const USER_AGENT: &str = "User-Agent";
const TRACE_PARENT: &str = "traceparent";

pub struct HTTPSessionMiddleware {
pub kind: HttpHandlerKind,
pub endpoint_kind: EndpointKind,
pub auth_manager: Arc<AuthMgr>,
}

impl HTTPSessionMiddleware {
pub fn create(kind: HttpHandlerKind, auth_manager: Arc<AuthMgr>) -> HTTPSessionMiddleware {
HTTPSessionMiddleware { kind, auth_manager }
pub fn create(kind: HttpHandlerKind, endpoint_kind: EndpointKind) -> HTTPSessionMiddleware {
let auth_manager = AuthMgr::instance();
HTTPSessionMiddleware {
kind,
endpoint_kind,
auth_manager,
}
}
}

Expand Down Expand Up @@ -109,7 +123,11 @@ fn extract_baggage_from_headers(headers: &HeaderMap) -> Option<Vec<(String, Stri
Some(result)
}

fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result<Credential> {
fn get_credential(
req: &Request,
kind: HttpHandlerKind,
endpoint_kind: EndpointKind,
) -> Result<Credential> {
let std_auth_headers: Vec<_> = req.headers().get_all(AUTHORIZATION).iter().collect();
if std_auth_headers.len() > 1 {
let msg = &format!("Multiple {} headers detected", AUTHORIZATION);
Expand All @@ -125,7 +143,12 @@ fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result<Credential> {
))
}
} else {
auth_by_header(&std_auth_headers, client_ip, req.uri().path())
auth_by_header(
&std_auth_headers,
client_ip,
endpoint_kind,
req.uri().path(),
)
}
}

Expand Down Expand Up @@ -159,6 +182,7 @@ pub fn get_client_ip(req: &Request) -> Option<String> {
fn auth_by_header(
std_auth_headers: &[&HeaderValue],
client_ip: Option<String>,
endpoint_kind: EndpointKind,
path: &str,
) -> Result<Credential> {
let value = &std_auth_headers[0];
Expand All @@ -182,18 +206,17 @@ fn auth_by_header(
Some(bearer) => {
let token = bearer.token().to_string();
if SessionClaim::is_databend_token(&token) {
let (token_type, set_user) = if path == "/query" {
(TokenType::Session, true)
} else if path == "/session/renew" {
(TokenType::Refresh, true)
} else if path != "/session/login" {
(TokenType::Session, false)
} else {
return Err(ErrorCode::AuthenticateFailure(format!(
"should not use databend auth when accessing {path}"
)));
let (token_type, set_user) = match endpoint_kind {
EndpointKind::Login => (TokenType::Session, false),
EndpointKind::Refresh => (TokenType::Refresh, true),
EndpointKind::Query => (TokenType::Session, true),
EndpointKind::Clickhouse => {
return Err(ErrorCode::AuthenticateFailure(format!(
"should not use databend auth when accessing {path}"
)));
}
};
Ok(Credential::Databend {
Ok(Credential::DatabendToken {
token,
token_type,
set_user,
Expand Down Expand Up @@ -246,6 +269,7 @@ impl<E: Endpoint> Middleware<E> for HTTPSessionMiddleware {
HTTPSessionEndpoint {
ep,
kind: self.kind,
endpoint_kind: self.endpoint_kind,
auth_manager: self.auth_manager.clone(),
}
}
Expand All @@ -254,13 +278,14 @@ impl<E: Endpoint> Middleware<E> for HTTPSessionMiddleware {
pub struct HTTPSessionEndpoint<E> {
ep: E,
pub kind: HttpHandlerKind,
pub endpoint_kind: EndpointKind,
pub auth_manager: Arc<AuthMgr>,
}

impl<E> HTTPSessionEndpoint<E> {
#[async_backtrace::framed]
async fn auth(&self, req: &Request, query_id: String) -> Result<HttpQueryContext> {
let credential = get_credential(req, self.kind)?;
let credential = get_credential(req, self.kind, self.endpoint_kind)?;

let session_manager = SessionManager::instance();

Expand All @@ -274,7 +299,7 @@ impl<E> HTTPSessionEndpoint<E> {

self.auth_manager.auth(&mut session, &credential).await?;
let databend_token = match credential {
Credential::Databend { token, .. } => Some(token),
Credential::DatabendToken { token, .. } => Some(token),
_ => None,
};

Expand Down
4 changes: 2 additions & 2 deletions src/query/service/tests/it/servers/http/clickhouse_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::collections::HashMap;

use databend_common_base::base::tokio;
use databend_query::auth::AuthMgr;
use databend_query::servers::http::middleware::EndpointKind;
use databend_query::servers::http::middleware::HTTPSessionEndpoint;
use databend_query::servers::http::middleware::HTTPSessionMiddleware;
use databend_query::servers::http::v1::clickhouse_router;
Expand Down Expand Up @@ -321,7 +321,7 @@ struct Server {
impl Server {
pub async fn new() -> Self {
let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Clickhouse, AuthMgr::instance());
HTTPSessionMiddleware::create(HttpHandlerKind::Clickhouse, EndpointKind::Clickhouse);
let endpoint = Route::new()
.nest("/", clickhouse_router())
.with(session_middleware);
Expand Down
Loading

0 comments on commit b8cfada

Please sign in to comment.