Skip to content

Commit

Permalink
feat: support dsn arg tls_ca_file (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Jul 13, 2023
1 parent 5ce8ccb commit 7a79b27
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 1 deletion.
15 changes: 15 additions & 0 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ pub struct APIClient {
max_rows_in_buffer: Option<i64>,
max_rows_per_page: Option<i64>,

tls_ca_file: Option<String>,

presigned_url_disabled: bool,
}

Expand Down Expand Up @@ -144,6 +146,9 @@ impl APIClient {
scheme = "http";
}
}
"tls_ca_file" => {
client.tls_ca_file = Some(v.to_string());
}
_ => {
session_settings.insert(k.to_string(), v.to_string());
}
Expand All @@ -157,6 +162,15 @@ impl APIClient {
_ => unreachable!(),
},
};

#[cfg(any(feature = "rustls", feature = "native-tls"))]
if scheme == "https" {
if let Some(ref ca_file) = client.tls_ca_file {
let cert_pem = std::fs::read(ca_file)?;
let cert = reqwest::Certificate::from_pem(&cert_pem)?;
client.cli = HttpClient::builder().add_root_certificate(cert).build()?;
}
}
client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
client.session_settings = Arc::new(Mutex::new(session_settings));

Expand Down Expand Up @@ -486,6 +500,7 @@ impl Default for APIClient {
wait_time_secs: None,
max_rows_in_buffer: None,
max_rows_per_page: None,
tls_ca_file: None,
presigned_url_disabled: false,
}
}
Expand Down
8 changes: 8 additions & 0 deletions core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub enum Error {
Parsing(String),
BadArgument(String),
Request(String),
IO(String),
InvalidResponse(response::QueryError),
InvalidPage(response::QueryError),
}
Expand All @@ -29,6 +30,7 @@ impl std::fmt::Display for Error {
Error::Parsing(msg) => write!(f, "ParsingError: {msg}"),
Error::BadArgument(msg) => write!(f, "BadArgument: {msg}"),
Error::Request(msg) => write!(f, "RequestError: {msg}"),
Error::IO(msg) => write!(f, "IOError: {msg}"),
Error::InvalidResponse(e) => {
write!(f, "ResponseError with {}: {}", e.code, e.message)
}
Expand Down Expand Up @@ -70,3 +72,9 @@ impl From<reqwest::Error> for Error {
Error::Request(e.to_string())
}
}

impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error::IO(e.to_string())
}
}
8 changes: 8 additions & 0 deletions driver/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub enum Error {
Parsing(String),
Protocol(String),
Transport(String),
IO(String),
BadArgument(String),
InvalidResponse(String),
Api(databend_client::error::Error),
Expand All @@ -53,6 +54,7 @@ impl std::fmt::Display for Error {
Error::Parsing(msg) => write!(f, "ParseError: {}", msg),
Error::Protocol(msg) => write!(f, "ProtocolError: {}", msg),
Error::Transport(msg) => write!(f, "TransportError: {}", msg),
Error::IO(msg) => write!(f, "IOError: {}", msg),

Error::BadArgument(msg) => write!(f, "BadArgument: {}", msg),
Error::InvalidResponse(msg) => write!(f, "ResponseError: {}", msg),
Expand Down Expand Up @@ -116,6 +118,12 @@ impl From<chrono::ParseError> for Error {
}
}

impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error::IO(e.to_string())
}
}

#[cfg(feature = "flight-sql")]
impl From<tonic::Status> for Error {
fn from(e: tonic::Status) -> Self {
Expand Down
13 changes: 12 additions & 1 deletion driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,16 @@ impl FlightSQLConnection {
.http2_keep_alive_interval(args.http2_keep_alive_interval)
.keep_alive_timeout(args.keep_alive_timeout)
.keep_alive_while_idle(args.keep_alive_while_idle);
#[cfg(any(feature = "rustls", feature = "native-tls"))]
if args.tls {
let tls_config = ClientTlsConfig::new();
let tls_config = match args.tls_ca_file {
None => ClientTlsConfig::new(),
Some(ref ca_file) => {
let pem = std::fs::read(ca_file)?;
let cert = tonic::transport::Certificate::from_pem(pem);
ClientTlsConfig::new().ca_certificate(cert)
}
};
endpoint = endpoint.tls_config(tls_config)?;
}
Ok((args, endpoint))
Expand All @@ -167,6 +175,7 @@ struct Args {
tenant: Option<String>,
warehouse: Option<String>,
tls: bool,
tls_ca_file: Option<String>,
connect_timeout: Duration,
query_timeout: Duration,
tcp_nodelay: bool,
Expand All @@ -187,6 +196,7 @@ impl Default for Args {
tenant: None,
warehouse: None,
tls: true,
tls_ca_file: None,
user: "root".to_string(),
password: "".to_string(),
connect_timeout: Duration::from_secs(20),
Expand Down Expand Up @@ -214,6 +224,7 @@ impl Args {
args.tls = false;
}
}
"tls_ca_file" => args.tls_ca_file = Some(v.to_string()),
"connect_timeout" => args.connect_timeout = Duration::from_secs(v.parse()?),
"query_timeout" => args.query_timeout = Duration::from_secs(v.parse()?),
"tcp_nodelay" => args.tcp_nodelay = v.parse()?,
Expand Down

0 comments on commit 7a79b27

Please sign in to comment.