Skip to content

Commit

Permalink
feat: support server side session (#479)
Browse files Browse the repository at this point in the history
* add token related error code.

* add login/logout/renew related structs.

* rename Error::SessionTimeout to  Error::QueryExpired

* add Error::AuthFailure.

* feat: use session_token by default.

* feat: support cookie.

* refactor: make build request sync.

* refactor login and logout.

* support disable_session_token

* add License.

* fix typo

* rm dep async-trait.

* allow Unicode-3.0

* test temp table

* adjustment against the new changes in server side.

* test change_password

* fix fmt

* fix cargo deny

* print server version

* test with nightly image

* flight sql skip tests.

* support  ArrowDataType::Utf8View.

* update test
  • Loading branch information
youngsofun authored Nov 13, 2024
1 parent d1a1d29 commit 6849bbe
Show file tree
Hide file tree
Showing 26 changed files with 847 additions and 259 deletions.
2 changes: 1 addition & 1 deletion bindings/nodejs/tests/binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ Then("Stream load and Select should be equal", async function () {
];
const progress = await this.conn.streamLoad(`INSERT INTO test VALUES`, values);
assert.equal(progress.writeRows, 3);
assert.equal(progress.writeBytes, 185);
assert.equal(progress.writeBytes, 187);

const rows = await this.conn.queryIter("SELECT * FROM test");
const ret = [];
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/tests/asyncio/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async def _(context):
]
progress = await context.conn.stream_load("INSERT INTO test VALUES", values)
assert progress.write_rows == 3, f"progress.write_rows: {progress.write_rows}"
assert progress.write_bytes == 185, f"progress.write_bytes: {progress.write_bytes}"
assert progress.write_bytes == 187, f"progress.write_bytes: {progress.write_bytes}"

rows = await context.conn.query_iter("SELECT * FROM test")
ret = []
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/tests/blocking/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _(context):
]
progress = context.conn.stream_load("INSERT INTO test VALUES", values)
assert progress.write_rows == 3, f"progress.write_rows: {progress.write_rows}"
assert progress.write_bytes == 185, f"progress.write_bytes: {progress.write_bytes}"
assert progress.write_bytes == 187, f"progress.write_bytes: {progress.write_bytes}"

rows = context.conn.query_iter("SELECT * FROM test")
ret = []
Expand Down
10 changes: 3 additions & 7 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,9 @@ pub async fn main() -> Result<()> {
// Exit client if user login failed.
if let Some(error) = err.downcast_ref::<databend_driver::Error>() {
match error {
databend_driver::Error::Api(
databend_client::error::Error::InvalidResponse(resp_err),
) => {
if resp_err.code == 401 {
println!("Authenticate failed wrong password user {}", user);
return Ok(());
}
databend_driver::Error::Api(databend_client::error::Error::AuthFailure(_)) => {
println!("Authenticate failed wrong password user {}", user);
return Ok(());
}
databend_driver::Error::Arrow(arrow::error::ArrowError::IpcError(ipc_err)) => {
if ipc_err.contains("Unauthenticated") {
Expand Down
6 changes: 2 additions & 4 deletions cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,9 @@ impl Session {
Err(err) => {
match err {
databend_driver::Error::Api(
databend_client::error::Error::InvalidResponse(ref resp_err),
databend_client::error::Error::AuthFailure(_),
) => {
if resp_err.code == 401 {
return Err(err.into());
}
return Err(err.into());
}
databend_driver::Error::Arrow(arrow::error::ArrowError::IpcError(
ref ipc_err,
Expand Down
4 changes: 2 additions & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ native-tls = ["reqwest/native-tls"]
[dependencies]
tokio-stream = { workspace = true }

async-trait = "0.1"
cookie = "0.18.1"
log = "0.4"
once_cell = "1.18"
parking_lot = "0.12.3"
percent-encoding = "2.3"
reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "stream"] }
reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "stream", "cookies"] }
serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = { version = "1.0", default-features = false, features = ["std"] }
tokio = { version = "1.34", features = ["macros"] }
Expand Down
33 changes: 17 additions & 16 deletions core/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ use reqwest::RequestBuilder;

use crate::error::{Error, Result};

#[async_trait::async_trait]
pub trait Auth: Sync + Send {
async fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder>;
fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder>;
fn can_reload(&self) -> bool {
false
}
fn username(&self) -> String;
}

Expand All @@ -37,9 +39,8 @@ impl BasicAuth {
}
}

#[async_trait::async_trait]
impl Auth for BasicAuth {
async fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
Ok(builder.basic_auth(&self.username, Some(self.password.inner())))
}

Expand All @@ -61,9 +62,8 @@ impl AccessTokenAuth {
}
}

#[async_trait::async_trait]
impl Auth for AccessTokenAuth {
async fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
Ok(builder.bearer_auth(self.token.inner()))
}

Expand All @@ -84,20 +84,21 @@ impl AccessTokenFileAuth {
}
}

#[async_trait::async_trait]
impl Auth for AccessTokenFileAuth {
async fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
let token = tokio::fs::read_to_string(&self.token_file)
.await
.map_err(|e| {
Error::IO(format!(
"cannot read access token from file {}: {}",
self.token_file, e
))
})?;
fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
let token = std::fs::read_to_string(&self.token_file).map_err(|e| {
Error::IO(format!(
"cannot read access token from file {}: {}",
self.token_file, e
))
})?;
Ok(builder.bearer_auth(token.trim()))
}

fn can_reload(&self) -> bool {
true
}

fn username(&self) -> String {
"token".to_string()
}
Expand Down
Loading

0 comments on commit 6849bbe

Please sign in to comment.