Skip to content

Commit

Permalink
feat: add label for skipif and onlyif conditions (#179)
Browse files Browse the repository at this point in the history
* use main.rs

* add label for condition

* update CHANGELOG

* fix docs

* Apply suggestions from code review

Co-authored-by: xxchan <xxchan22f@gmail.com>
Signed-off-by: Runji Wang <wangrunji0408@163.com>

---------

Signed-off-by: Runji Wang <wangrunji0408@163.com>
Co-authored-by: xxchan <xxchan22f@gmail.com>
  • Loading branch information
wangrunji0408 and xxchan authored Jun 9, 2023
1 parent 27eb9f5 commit 4685291
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 46 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.14.0] - 2023-06-08

* We enhanced how `skipif` and `onlyif` works. Previously it checks against `DB::engine_name()`, and `sqllogictest-bin` didn't implement it.
- (parser) A minor **breaking change**: Change the field names of `Condition:: OnlyIf/SkipIf`.
- (runner) Add `Runner::add_label`. Now multiple labels are supported ( `DB::engine_name()` is still included). The condition evaluates to true if *any* of the provided labels match the `skipif/onlyif <lable>`.
- (bin) Add `--label` option to specify custom labels.

## [0.13.2] - 2023-03-24

* `Runner::update_test_file` properly escapes regex special characters.
Expand Down
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
members = ["examples/*", "sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]

[workspace.package]
version = "0.13.2"
version = "0.14.0"
edition = "2021"
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
keywords = ["sql", "database", "parser", "cli"]
Expand Down
8 changes: 6 additions & 2 deletions sqllogictest-bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ license = { workspace = true }
repository = { workspace = true }
description = "Sqllogictest CLI."

[[bin]]
name = "sqllogictest"
path = "src/main.rs"

[dependencies]
anyhow = { version = "1" }
async-trait = "0.1"
Expand All @@ -20,8 +24,8 @@ glob = "0.3"
itertools = "0.10"
quick-junit = { version = "0.2" }
rand = "0.8"
sqllogictest = { path = "../sqllogictest", version = "0.13" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.13" }
sqllogictest = { path = "../sqllogictest", version = "0.14" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.14" }
tokio = { version = "1", features = [
"rt",
"rt-multi-thread",
Expand Down
9 changes: 0 additions & 9 deletions sqllogictest-bin/src/bin/sqllogictest.rs

This file was deleted.

52 changes: 46 additions & 6 deletions sqllogictest-bin/src/lib.rs → sqllogictest-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ struct Opt {
/// Reformats the test files.
#[clap(long)]
format: bool,

/// Add a label for conditions.
///
/// Records with `skipif label` will be skipped if the label is present.
/// Records with `onlyif label` will be executed only if the label is present.
///
/// The engine name is a label by default.
#[clap(long = "label")]
labels: Vec<String>,
}

/// Connection configuration.
Expand All @@ -119,7 +128,8 @@ impl DBConfig {
}
}

pub async fn main_okk() -> Result<()> {
#[tokio::main]
pub async fn main() -> Result<()> {
env_logger::init();

let Opt {
Expand All @@ -137,6 +147,7 @@ pub async fn main_okk() -> Result<()> {
options,
r#override,
format,
labels,
} = Opt::parse();

if host.len() != port.len() {
Expand Down Expand Up @@ -204,9 +215,26 @@ pub async fn main_okk() -> Result<()> {
test_suite.set_timestamp(Local::now());

let result = if let Some(jobs) = jobs {
run_parallel(jobs, &mut test_suite, files, &engine, config, junit.clone()).await
run_parallel(
jobs,
&mut test_suite,
files,
&engine,
config,
&labels,
junit.clone(),
)
.await
} else {
run_serial(&mut test_suite, files, &engine, config, junit.clone()).await
run_serial(
&mut test_suite,
files,
&engine,
config,
&labels,
junit.clone(),
)
.await
};

report.add_test_suite(test_suite);
Expand All @@ -224,6 +252,7 @@ async fn run_parallel(
files: Vec<PathBuf>,
engine: &EngineConfig,
config: DBConfig,
labels: &[String],
junit: Option<String>,
) -> Result<()> {
let mut create_databases = BTreeMap::new();
Expand Down Expand Up @@ -257,10 +286,13 @@ async fn run_parallel(
config.db = db_name;
let file = filename.to_string_lossy().to_string();
let engine = engine.clone();
let labels = labels.to_vec();
async move {
let (buf, res) = tokio::spawn(async move {
let mut buf = vec![];
let res = connect_and_run_test_file(&mut buf, filename, &engine, config).await;
let res =
connect_and_run_test_file(&mut buf, filename, &engine, config, &labels)
.await;
(buf, res)
})
.await
Expand Down Expand Up @@ -331,13 +363,17 @@ async fn run_serial(
files: Vec<PathBuf>,
engine: &EngineConfig,
config: DBConfig,
labels: &[String],
junit: Option<String>,
) -> Result<()> {
let mut failed_case = vec![];

for file in files {
let engine = engines::connect(engine, &config).await?;
let runner = Runner::new(engine);
let mut runner = Runner::new(engine);
for label in labels {
runner.add_label(label);
}

let filename = file.to_string_lossy().to_string();
let test_case_name = filename.replace(['/', ' ', '.', '-'], "_");
Expand Down Expand Up @@ -405,9 +441,13 @@ async fn connect_and_run_test_file(
filename: PathBuf,
engine: &EngineConfig,
config: DBConfig,
labels: &[String],
) -> Result<Duration> {
let engine = engines::connect(engine, &config).await?;
let runner = Runner::new(engine);
let mut runner = Runner::new(engine);
for label in labels {
runner.add_label(label);
}
let result = run_test_file(out, runner, filename).await?;

Ok(result)
Expand Down
2 changes: 1 addition & 1 deletion sqllogictest-engines/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ postgres-types = { version = "0.2.3", features = ["derive", "with-chrono-0_4"] }
rust_decimal = { version = "1.7.0", features = ["tokio-pg"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sqllogictest = { path = "../sqllogictest", version = "0.13" }
sqllogictest = { path = "../sqllogictest", version = "0.14" }
thiserror = "1"
tokio = { version = "1", features = [
"rt",
Expand Down
36 changes: 16 additions & 20 deletions sqllogictest/src/parser.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
//! Sqllogictest parser.

use std::collections::HashSet;
use std::fmt;
use std::path::Path;
use std::sync::Arc;
Expand Down Expand Up @@ -244,12 +246,8 @@ impl<T: ColumnType> std::fmt::Display for Record<T> {
Control::SortMode(m) => write!(f, "control sortmode {}", m.as_str()),
},
Record::Condition(cond) => match cond {
Condition::OnlyIf { engine_name } => {
write!(f, "onlyif {engine_name}")
}
Condition::SkipIf { engine_name } => {
write!(f, "skipif {engine_name}")
}
Condition::OnlyIf { label } => write!(f, "onlyif {label}"),
Condition::SkipIf { label } => write!(f, "skipif {label}"),
},
Record::HashThreshold { loc: _, threshold } => {
write!(f, "hash-threshold {threshold}")
Expand Down Expand Up @@ -287,20 +285,18 @@ pub enum Injected {
/// The condition to run a query.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Condition {
/// The statement or query is skipped if an `onlyif` record for a different database engine is
/// seen.
OnlyIf { engine_name: String },
/// The statement or query is not evaluated if a `skipif` record for the target database engine
/// is seen in the prefix.
SkipIf { engine_name: String },
/// The statement or query is evaluated only if the label is seen.
OnlyIf { label: String },
/// The statement or query is not evaluated if the label is seen.
SkipIf { label: String },
}

impl Condition {
/// Evaluate condition on given `targe_name`, returns whether to skip this record.
pub fn should_skip(&self, target_name: &str) -> bool {
/// Evaluate condition on given `label`, returns whether to skip this record.
pub(crate) fn should_skip(&self, labels: &HashSet<String>) -> bool {
match self {
Condition::OnlyIf { engine_name } => engine_name != target_name,
Condition::SkipIf { engine_name } => engine_name == target_name,
Condition::OnlyIf { label } => !labels.contains(label),
Condition::SkipIf { label } => labels.contains(label),
}
}
}
Expand Down Expand Up @@ -457,16 +453,16 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
loc,
});
}
["skipif", engine_name] => {
["skipif", label] => {
let cond = Condition::SkipIf {
engine_name: engine_name.to_string(),
label: label.to_string(),
};
conditions.push(cond.clone());
records.push(Record::Condition(cond));
}
["onlyif", engine_name] => {
["onlyif", label] => {
let cond = Condition::OnlyIf {
engine_name: engine_name.to_string(),
label: label.to_string(),
};
conditions.push(cond.clone());
records.push(Record::Condition(cond));
Expand Down
15 changes: 11 additions & 4 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Sqllogictest runner.

use std::collections::HashSet;
use std::fmt::{Debug, Display};
use std::path::Path;
use std::sync::Arc;
Expand Down Expand Up @@ -459,21 +460,29 @@ pub struct Runner<D: AsyncDB> {
sort_mode: Option<SortMode>,
/// 0 means never hashing
hash_threshold: usize,
/// Labels for condition `skipif` and `onlyif`.
labels: HashSet<String>,
}

impl<D: AsyncDB> Runner<D> {
/// Create a new test runner on the database.
pub fn new(db: D) -> Self {
Runner {
db,
validator: default_validator,
column_type_validator: default_column_validator,
testdir: None,
sort_mode: None,
hash_threshold: 0,
labels: [db.engine_name().to_string()].into_iter().collect(),
db,
}
}

/// Add a label for condition `skipif` and `onlyif`.
pub fn add_label(&mut self, label: &str) {
self.labels.insert(label.to_string());
}

/// Replace the pattern `__TEST_DIR__` in SQL with a temporary directory path.
///
/// This feature is useful in those tests where data will be written to local
Expand Down Expand Up @@ -913,9 +922,7 @@ impl<D: AsyncDB> Runner<D> {

/// Returns whether we should skip this record, according to given `conditions`.
fn should_skip(&self, conditions: &[Condition]) -> bool {
conditions
.iter()
.any(|c| c.should_skip(self.db.engine_name()))
conditions.iter().any(|c| c.should_skip(&self.labels))
}

/// Updates a test file with the output produced by a Database. It is an utility function
Expand Down

0 comments on commit 4685291

Please sign in to comment.