Skip to content

Commit

Permalink
support set and show on statement/execution timeout session variables.
Browse files Browse the repository at this point in the history
  • Loading branch information
lyang24 committed Nov 5, 2024
1 parent ac387bd commit eca4b0e
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 13 deletions.
26 changes: 26 additions & 0 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use query::parser::QueryStatement;
use query::QueryEngineRef;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use set::set_query_timeout;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
use sql::statements::set_variables::SetVariables;
Expand Down Expand Up @@ -338,6 +339,31 @@ impl StatementExecutor {
"DATESTYLE" => set_datestyle(set_var.value, query_ctx)?,

"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
// TODO: write sqlness test for query timeout variables
// once the proper channel is configured in the test infra.
// The current sqlness test channel is default to Unknown.
"MAX_EXECUTION_TIME" => match query_ctx.channel() {
Channel::Mysql => set_query_timeout(set_var.value, query_ctx)?,
Channel::Postgres => {
query_ctx.set_warning(format!("Unsupported set variable {}", var_name))
}
_ => {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail()
}
},
"STATEMENT_TIMEOUT" => {
if query_ctx.channel() == Channel::Postgres {
set_query_timeout(set_var.value, query_ctx)?
} else {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail();
}
}
_ => {
// for postgres, we give unknown SET statements a warning with
// success, this is prevent the SET call becoming a blocker
Expand Down
97 changes: 96 additions & 1 deletion src/operator/src/statement/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::time::Duration;

use common_time::Timezone;
use session::context::Channel::Postgres;
use session::context::QueryContextRef;
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::{Expr, Ident, Value};
use sql::statements::set_variables::SetVariables;

use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
use crate::error::{
BuildRegexSnafu, InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result,
};

pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let tz_expr = exprs.first().context(NotSupportedSnafu {
Expand Down Expand Up @@ -177,3 +182,93 @@ pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
.set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
Ok(())
}

pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let timeout_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timeout value find in set query timeout statement",
})?;
match timeout_expr {
Expr::Value(Value::Number(timeout, _)) => {
match timeout.parse::<u64>() {
Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
Err(_) => {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail()
}
}
Ok(())
}
// postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
Expr::Value(Value::SingleQuotedString(timeout))
| Expr::Value(Value::DoubleQuotedString(timeout)) => {
if ctx.channel() != Postgres {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail();
}
let timeout = parse_pg_query_timeout_input(timeout)?;
ctx.set_query_timeout(Duration::from_millis(timeout));
Ok(())
}
expr => NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
expr
),
}
.fail(),
}
}

// support time units in ms, s, min, h, d for postgres protocol.
// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
// Regex rules:
// The string must start with a number (one or more digits).
// The number must be followed by one of the valid time units (ms, s, min, h, d).
// The string must end immediately after the unit, meaning there can be no extra
// characters or spaces after the valid time specification.
let re = regex::Regex::new(r"^(\d+)(ms|s|min|h|d)$").context(BuildRegexSnafu)?;
if let Some(captures) = re.captures(input) {
let value = captures[1].parse::<u64>().expect("regex failed");
let unit = &captures[2];

match unit {
"ms" => Ok(value),
"s" => Ok(value * 1000),
"min" => Ok(value * 60 * 1000),
"h" => Ok(value * 60 * 60 * 1000),
"d" => Ok(value * 24 * 60 * 60 * 1000),
_ => unreachable!("regex failed"),
}
} else {
NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
input
),
}
.fail()
}
}

#[cfg(test)]
mod test {
use crate::statement::set::parse_pg_query_timeout_input;

#[test]
fn test_parse_pg_query_timeout_input() {
assert!(parse_pg_query_timeout_input("").is_err());
assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
assert!(parse_pg_query_timeout_input("3a").is_err());
assert!(parse_pg_query_timeout_input("1.5min").is_err());

assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
}
}
19 changes: 18 additions & 1 deletion src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use datatypes::vectors::StringVector;
use object_store::ObjectStore;
use once_cell::sync::Lazy;
use regex::Regex;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContextRef};
pub use show_create_table::create_table_stmt;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::Ident;
Expand Down Expand Up @@ -651,6 +651,23 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result<
let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style();
format!("{}, {}", style, order)
}
"MAX_EXECUTION_TIME" => {
if query_ctx.channel() == Channel::Mysql {
query_ctx.query_timeout_as_millis().to_string()
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
"STATEMENT_TIMEOUT" => {
// Add time units to postgres query timeout display.
if query_ctx.channel() == Channel::Postgres {
let mut timeout = query_ctx.query_timeout_as_millis().to_string();
timeout.push_str("ms");
timeout
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
_ => return UnsupportedVariableSnafu { name: variable }.fail(),
};
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
Expand Down
17 changes: 17 additions & 0 deletions src/session/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use api::v1::region::RegionRequestHeader;
use arc_swap::ArcSwap;
Expand Down Expand Up @@ -282,6 +283,22 @@ impl QueryContext {
pub fn set_warning(&self, msg: String) {
self.mutable_query_context_data.write().unwrap().warning = Some(msg);
}

pub fn query_timeout(&self) -> Option<Duration> {
self.mutable_session_data.read().unwrap().query_timeout
}

pub fn query_timeout_as_millis(&self) -> u128 {
let timeout = self.mutable_session_data.read().unwrap().query_timeout;
if let Some(t) = timeout {
return t.as_millis();
}
0
}

pub fn set_query_timeout(&self, timeout: Duration) {
self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
}
}

impl QueryContextBuilder {
Expand Down
3 changes: 3 additions & 0 deletions src/session/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod table_name;

use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use auth::UserInfoRef;
use common_catalog::build_db_string;
Expand Down Expand Up @@ -45,6 +46,7 @@ pub(crate) struct MutableInner {
schema: String,
user_info: UserInfoRef,
timezone: Timezone,
query_timeout: Option<Duration>,
}

impl Default for MutableInner {
Expand All @@ -53,6 +55,7 @@ impl Default for MutableInner {
schema: DEFAULT_SCHEMA_NAME.into(),
user_info: auth::userinfo_by_name(None),
timezone: get_timezone(None).clone(),
query_timeout: None,
}
}
}
Expand Down
58 changes: 47 additions & 11 deletions src/sql/src/parsers/set_var_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,47 +58,83 @@ mod tests {
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParseOptions;

fn assert_mysql_parse_result(sql: &str) {
fn assert_mysql_parse_result(sql: &str, indent_str: &str, expr: Expr) {
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
let mut stmts = result.unwrap();
assert_eq!(
stmts.pop().unwrap(),
Statement::SetVariables(SetVariables {
variable: ObjectName(vec![Ident::new("time_zone")]),
value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))]
variable: ObjectName(vec![Ident::new(indent_str)]),
value: vec![expr]
})
);
}

fn assert_pg_parse_result(sql: &str) {
fn assert_pg_parse_result(sql: &str, indent: &str, expr: Expr) {
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
let mut stmts = result.unwrap();
assert_eq!(
stmts.pop().unwrap(),
Statement::SetVariables(SetVariables {
variable: ObjectName(vec![Ident::new("TIMEZONE")]),
value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))],
variable: ObjectName(vec![Ident::new(indent)]),
value: vec![expr],
})
);
}

#[test]
pub fn test_set_timezone() {
let expected_utc_expr = Expr::Value(Value::SingleQuotedString("UTC".to_string()));
// mysql style
let sql = "SET time_zone = 'UTC'";
assert_mysql_parse_result(sql);
assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone());
// session or local style
let sql = "SET LOCAL time_zone = 'UTC'";
assert_mysql_parse_result(sql);
assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone());
let sql = "SET SESSION time_zone = 'UTC'";
assert_mysql_parse_result(sql);
assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone());

// postgresql style
let sql = "SET TIMEZONE TO 'UTC'";
assert_pg_parse_result(sql);
assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr.clone());
let sql = "SET TIMEZONE 'UTC'";
assert_pg_parse_result(sql);
assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr);
}

#[test]
pub fn test_set_query_timeout() {
let expected_query_timeout_expr = Expr::Value(Value::Number("5000".to_string(), false));
// mysql style
let sql = "SET MAX_EXECUTION_TIME = 5000";
assert_mysql_parse_result(
sql,
"MAX_EXECUTION_TIME",
expected_query_timeout_expr.clone(),
);
// session or local style
let sql = "SET LOCAL MAX_EXECUTION_TIME = 5000";
assert_mysql_parse_result(
sql,
"MAX_EXECUTION_TIME",
expected_query_timeout_expr.clone(),
);
let sql = "SET SESSION MAX_EXECUTION_TIME = 5000";
assert_mysql_parse_result(
sql,
"MAX_EXECUTION_TIME",
expected_query_timeout_expr.clone(),
);

// postgresql style
let sql = "SET STATEMENT_TIMEOUT = 5000";
assert_pg_parse_result(
sql,
"STATEMENT_TIMEOUT",
expected_query_timeout_expr.clone(),
);
let sql = "SET STATEMENT_TIMEOUT TO 5000";
assert_pg_parse_result(sql, "STATEMENT_TIMEOUT", expected_query_timeout_expr);
}
}

0 comments on commit eca4b0e

Please sign in to comment.