Skip to content

Commit

Permalink
refactor: assorted FlightSqlServiceClient improvements
Browse files Browse the repository at this point in the history
- **TLS config:** Do NOT alter existing method signatures if the TLS
  feature is enabled. Features should be purely additive in Rust.
  Instead use a new method to pass TLS configs. The config is now passed
  as `ClientTlsConfig` to allow more flexibility, e.g. just to use TLS
  w/o any client certs.
- **token handling:** Allow the token to be passed in from an external
  source. The [auth spec] is super flexible ("application-defined")
  and we cannot derive a way to determine the token in all cases. The
  current handshake-based mechanism is OK though. Also make sure the
  token is used in all relevant methods.
- **headers:** Allow users to pass in additional headers. This is
  helpful for certain applications.

[auth spec]: https://arrow.apache.org/docs/format/Flight.html#authentication
  • Loading branch information
crepererum committed Mar 2, 2023
1 parent 7852e76 commit 8a118bd
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 51 deletions.
10 changes: 6 additions & 4 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,12 +580,14 @@ mod tests {
let key = std::fs::read_to_string("examples/data/client1.key").unwrap();
let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap();

let mut client = FlightSqlServiceClient::new_with_endpoint(
Identity::from_pem(cert, key),
Certificate::from_pem(&server_ca),
"localhost",
let tls_config = ClientTlsConfig::new()
.domain_name("localhost")
.ca_certificate(Certificate::from_pem(&server_ca))
.identity(Identity::from_pem(cert, key));
let mut client = FlightSqlServiceClient::new_with_tls_endpoint(
"127.0.0.1",
50051,
tls_config,
)
.await
.unwrap();
Expand Down
120 changes: 73 additions & 47 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::Bytes;
use std::collections::HashMap;
use std::str::FromStr;
use std::time::Duration;
use tonic::metadata::AsciiMetadataKey;

use crate::flight_service_client::FlightServiceClient;
use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
Expand All @@ -45,15 +47,16 @@ use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{stream, TryStreamExt};
use prost::Message;
#[cfg(feature = "tls")]
use tonic::transport::{Certificate, ClientTlsConfig, Identity};
use tonic::transport::ClientTlsConfig;
use tonic::transport::{Channel, Endpoint};
use tonic::Streaming;
use tonic::{IntoRequest, Streaming};

/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data
/// by FlightSQL protocol.
///
/// Besides the underlying Flight client, it carries an optional auth token and a
/// set of extra headers that are attached to outgoing requests.
#[derive(Debug, Clone)]
pub struct FlightSqlServiceClient {
/// Auth token; starts out as `None` and, when set, is sent as an
/// `authorization: Bearer <token>` metadata entry on requests.
token: Option<String>,
/// Additional user-supplied headers (key -> value) added to requests as metadata.
headers: HashMap<String, String>,
/// Underlying Flight service client over a tonic `Channel`, used to issue the RPCs.
flight_client: FlightServiceClient<Channel>,
}

Expand All @@ -62,19 +65,8 @@ pub struct FlightSqlServiceClient {
/// Github issues are welcomed.
impl FlightSqlServiceClient {
/// Creates a new FlightSql Client that connects via TCP to a server
#[cfg(not(feature = "tls"))]
pub async fn new_with_endpoint(host: &str, port: u16) -> Result<Self, ArrowError> {
let addr = format!("http://{}:{}", host, port);
let endpoint = Endpoint::new(addr)
.map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?
.connect_timeout(Duration::from_secs(20))
.timeout(Duration::from_secs(20))
.tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait
.tcp_keepalive(Option::Some(Duration::from_secs(3600)))
.http2_keep_alive_interval(Duration::from_secs(300))
.keep_alive_timeout(Duration::from_secs(20))
.keep_alive_while_idle(true);

let endpoint = Self::endpoint(host, port)?;
let channel = endpoint.connect().await.map_err(|e| {
ArrowError::IoError(format!("Cannot connect to endpoint: {}", e))
})?;
Expand All @@ -83,13 +75,23 @@ impl FlightSqlServiceClient {

/// Creates a new HTTPS FlightSql client that connects via TCP to a server,
/// applying the caller-supplied `tls_config` to the endpoint before connecting.
///
/// # Errors
///
/// Propagates any error from `Self::endpoint`, and returns `ArrowError::IoError`
/// if the TLS configuration cannot be applied or the connection attempt fails.
#[cfg(feature = "tls")]
// NOTE(review): the next four lines (`new_with_endpoint(` through `domain: &str,`) look
// like the pre-change signature left in by the diff rendering — confirm before relying on them.
pub async fn new_with_endpoint(
client_ident: Identity,
server_ca: Certificate,
domain: &str,
pub async fn new_with_tls_endpoint(
host: &str,
port: u16,
tls_config: ClientTlsConfig,
) -> Result<Self, ArrowError> {
// Build the base endpoint first, then layer the TLS settings on top of it.
let endpoint = Self::endpoint(host, port)?;
let endpoint = endpoint
.tls_config(tls_config)
.map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?;

// Establish the channel; the underlying connect error is included in the message.
let channel = endpoint.connect().await.map_err(|e| {
ArrowError::IoError(format!("Cannot connect to endpoint: {e}"))
})?;
Ok(Self::new(channel))
}

fn endpoint(host: &str, port: u16) -> Result<Endpoint, ArrowError> {
let addr = format!("https://{host}:{port}");

let endpoint = Endpoint::new(addr)
Expand All @@ -102,19 +104,7 @@ impl FlightSqlServiceClient {
.keep_alive_timeout(Duration::from_secs(20))
.keep_alive_while_idle(true);

let tls_config = ClientTlsConfig::new()
.domain_name(domain)
.ca_certificate(server_ca)
.identity(client_ident);

let endpoint = endpoint
.tls_config(tls_config)
.map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?;

let channel = endpoint.connect().await.map_err(|e| {
ArrowError::IoError(format!("Cannot connect to endpoint: {e}"))
})?;
Ok(Self::new(channel))
Ok(endpoint)
}

/// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel`
Expand All @@ -123,6 +113,7 @@ impl FlightSqlServiceClient {
FlightSqlServiceClient {
token: None,
flight_client,
headers: HashMap::default(),
}
}

Expand All @@ -141,14 +132,27 @@ impl FlightSqlServiceClient {
self.flight_client
}

/// Stores `token` as the client's auth token, discarding any token held before.
///
/// The stored token is what `set_request_headers` sends as the
/// `authorization: Bearer <token>` metadata entry.
pub fn set_token(&mut self, token: String) {
    // The previous token (if any) is not needed; drop it right away.
    let _ = self.token.replace(token);
}

/// Registers an additional header to attach to requests.
///
/// Both `key` and `value` are converted to owned `String`s and stored in the
/// client's header map; storing a key that already exists replaces its value.
pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
    self.headers.insert(key.into(), value.into());
}

async fn get_flight_info_for_command<M: ProstMessageExt>(
&mut self,
cmd: M,
) -> Result<FlightInfo, ArrowError> {
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let req = self.set_request_headers(descriptor.into_request())?;
let fi = self
.flight_client
.get_flight_info(descriptor)
.get_flight_info(req)
.await
.map_err(status_to_arrow_error)?
.into_inner();
Expand Down Expand Up @@ -178,6 +182,7 @@ impl FlightSqlServiceClient {
.parse()
.map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
req.metadata_mut().insert("authorization", val);
let req = self.set_request_headers(req)?;
let resp = self
.flight_client
.handshake(req)
Expand All @@ -199,25 +204,29 @@ impl FlightSqlServiceClient {
ArrowError::ParseError("Can't collect responses".to_string())
})?;
let resp = match responses.as_slice() {
[resp] => resp,
[] => Err(ArrowError::ParseError("No handshake response".to_string()))?,
[resp] => resp.payload.clone(),
[] => Bytes::new(),
_ => Err(ArrowError::ParseError(
"Multiple handshake responses".to_string(),
))?,
};
Ok(resp.payload.clone())
Ok(resp)
}

/// Execute a update query on the server, and return the number of records affected
pub async fn execute_update(&mut self, query: String) -> Result<i64, ArrowError> {
let cmd = CommandStatementUpdate { query };
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let mut result = self
.flight_client
.do_put(stream::iter(vec![FlightData {
let req = self.set_request_headers(
stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
}]))
}])
.into_request(),
)?;
let mut result = self
.flight_client
.do_put(req)
.await
.map_err(status_to_arrow_error)?
.into_inner();
Expand Down Expand Up @@ -251,9 +260,10 @@ impl FlightSqlServiceClient {
&mut self,
ticket: Ticket,
) -> Result<Streaming<FlightData>, ArrowError> {
let req = self.set_request_headers(ticket.into_request())?;
Ok(self
.flight_client
.do_get(ticket)
.do_get(req)
.await
.map_err(status_to_arrow_error)?
.into_inner())
Expand Down Expand Up @@ -329,13 +339,7 @@ impl FlightSqlServiceClient {
r#type: CREATE_PREPARED_STATEMENT.to_string(),
body: cmd.as_any().encode_to_vec().into(),
};
let mut req = tonic::Request::new(action);
if let Some(token) = &self.token {
let val = format!("Bearer {token}").parse().map_err(|_| {
ArrowError::IoError("Statement already closed.".to_string())
})?;
req.metadata_mut().insert("authorization", val);
}
let req = self.set_request_headers(action.into_request())?;
let mut result = self
.flight_client
.do_action(req)
Expand Down Expand Up @@ -369,6 +373,28 @@ impl FlightSqlServiceClient {
pub async fn close(&mut self) -> Result<(), ArrowError> {
// Currently a no-op: no request is sent and no client state is touched,
// so this always returns `Ok(())`.
Ok(())
}

/// Applies the client's configured metadata to an outgoing request and returns it.
///
/// Every entry stored in `self.headers` is inserted first. Afterwards, if a token
/// is present, an `authorization: Bearer <token>` entry is inserted, so the token
/// takes precedence over a user-supplied `authorization` header.
///
/// # Errors
///
/// Returns `ArrowError::IoError` when a header key, a header value, or the bearer
/// token cannot be converted into request metadata.
fn set_request_headers<T>(
    &self,
    mut req: tonic::Request<T>,
) -> Result<tonic::Request<T>, ArrowError> {
    let metadata = req.metadata_mut();

    // User-supplied headers go in first.
    for (k, v) in self.headers.iter() {
        let key = AsciiMetadataKey::from_str(k).map_err(|e| {
            ArrowError::IoError(format!("Cannot convert header key \"{k}\": {e}"))
        })?;
        let value = v.parse().map_err(|e| {
            ArrowError::IoError(format!("Cannot convert header value \"{v}\": {e}"))
        })?;
        metadata.insert(key, value);
    }

    // The bearer token is written last so it overrides any earlier `authorization` entry.
    if let Some(token) = self.token.as_deref() {
        let bearer = format!("Bearer {token}");
        let val = bearer.parse().map_err(|e| {
            ArrowError::IoError(format!("Cannot convert token to header value: {e}"))
        })?;
        metadata.insert("authorization", val);
    }

    Ok(req)
}
}

/// A PreparedStatement
Expand Down

0 comments on commit 8a118bd

Please sign in to comment.