Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidm-db committed Sep 4, 2024
1 parent 3a09416 commit b5b6b25
Show file tree
Hide file tree
Showing 3 changed files with 426 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,184 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
assert(whileStmt.label.contains("lbl"))
}

test("leave compound block") {
val sqlScriptText =
"""
|lbl: BEGIN
| SELECT 1;
| LEAVE lbl;
|END""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 2)
assert(tree.collection.head.isInstanceOf[SingleStatement])
assert(tree.collection(1).isInstanceOf[LeaveStatement])
}

test("leave while loop") {
val sqlScriptText =
"""
|BEGIN
| lbl: WHILE 1 = 1 DO
| SELECT 1;
| LEAVE lbl;
| END WHILE;
|END""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[WhileStatement])

val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
assert(whileStmt.condition.isInstanceOf[SingleStatement])
assert(whileStmt.condition.getText == "1 = 1")

assert(whileStmt.body.isInstanceOf[CompoundBody])
assert(whileStmt.body.collection.length == 2)

assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement])
assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1")

assert(whileStmt.body.collection(1).isInstanceOf[LeaveStatement])
assert(whileStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl")
}

test ("iterate compound block - should fail") {
val sqlScriptText =
"""
|lbl: BEGIN
| SELECT 1;
| ITERATE lbl;
|END""".stripMargin
checkError(
exception = intercept[SqlScriptingException] {
parseScript(sqlScriptText)
},
errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
parameters = Map("labelName" -> "LBL", "statementType" -> "ITERATE"))
}

test("iterate while loop") {
val sqlScriptText =
"""
|BEGIN
| lbl: WHILE 1 = 1 DO
| SELECT 1;
| ITERATE lbl;
| END WHILE;
|END""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[WhileStatement])

val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
assert(whileStmt.condition.isInstanceOf[SingleStatement])
assert(whileStmt.condition.getText == "1 = 1")

assert(whileStmt.body.isInstanceOf[CompoundBody])
assert(whileStmt.body.collection.length == 2)

assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement])
assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1")

assert(whileStmt.body.collection(1).isInstanceOf[IterateStatement])
assert(whileStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl")
}

test("leave with wrong label - should fail") {
val sqlScriptText =
"""
|lbl: BEGIN
| SELECT 1;
| LEAVE randomlbl;
|END""".stripMargin
checkError(
exception = intercept[SqlScriptingException] {
parseScript(sqlScriptText)
},
errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE"))
}

test("iterate with wrong label - should fail") {
val sqlScriptText =
"""
|lbl: BEGIN
| SELECT 1;
| ITERATE randomlbl;
|END""".stripMargin
checkError(
exception = intercept[SqlScriptingException] {
parseScript(sqlScriptText)
},
errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "ITERATE"))
}

test("leave outer loop from nested while loop") {
val sqlScriptText =
"""
|BEGIN
| lbl: WHILE 1 = 1 DO
| lbl2: WHILE 2 = 2 DO
| SELECT 1;
| LEAVE lbl;
| END WHILE;
| END WHILE;
|END""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[WhileStatement])

val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
assert(whileStmt.condition.isInstanceOf[SingleStatement])
assert(whileStmt.condition.getText == "1 = 1")

assert(whileStmt.body.isInstanceOf[CompoundBody])
assert(whileStmt.body.collection.length == 1)

val nestedWhileStmt = whileStmt.body.collection.head.asInstanceOf[WhileStatement]
assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement])
assert(nestedWhileStmt.condition.getText == "2 = 2")

assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement])
assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1")

assert(nestedWhileStmt.body.collection(1).isInstanceOf[LeaveStatement])
assert(nestedWhileStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl")
}

test("iterate outer loop from nested while loop") {
val sqlScriptText =
"""
|BEGIN
| lbl: WHILE 1 = 1 DO
| lbl2: WHILE 2 = 2 DO
| SELECT 1;
| ITERATE lbl;
| END WHILE;
| END WHILE;
|END""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[WhileStatement])

val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
assert(whileStmt.condition.isInstanceOf[SingleStatement])
assert(whileStmt.condition.getText == "1 = 1")

assert(whileStmt.body.isInstanceOf[CompoundBody])
assert(whileStmt.body.collection.length == 1)

val nestedWhileStmt = whileStmt.body.collection.head.asInstanceOf[WhileStatement]
assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement])
assert(nestedWhileStmt.condition.getText == "2 = 2")

assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement])
assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1")

assert(nestedWhileStmt.body.collection(1).isInstanceOf[IterateStatement])
assert(nestedWhileStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl")
}

// Helper methods
def cleanupStatementString(statementStr: String): String = {
statementStr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
override def reset(): Unit = ()
}

case class TestLeaveStatement(labelText: String) extends LeaveStatementExec(labelText)

case class TestIterateStatement(labelText: String) extends IterateStatementExec(labelText)

case class TestIfElseCondition(condVal: Boolean, description: String)
extends SingleStatementExec(
parsedPlan = Project(Seq(Alias(Literal(condVal), description)()), OneRowRelation()),
Expand All @@ -54,8 +58,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi

case class TestWhile(
condition: TestWhileCondition,
body: CompoundBodyExec)
extends WhileStatementExec(condition, body, None, spark) {
body: CompoundBodyExec,
label: Option[String] = None)
extends WhileStatementExec(condition, body, label, spark) {

private var callCount: Int = 0

Expand All @@ -77,6 +82,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
case TestLeafStatement(testVal) => testVal
case TestIfElseCondition(_, description) => description
case TestWhileCondition(_, _, description) => description
case TestLeaveStatement(label) => label
case TestIterateStatement(label) => label
case _ => fail("Unexpected statement type")
}

Expand Down Expand Up @@ -314,4 +321,100 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
"con2", "body1", "con2", "con1"))
}

test("leave compound block") {
val iter = new CompoundBodyExec(
statements = Seq(
TestLeafStatement("one"),
TestLeaveStatement("lbl")
),
label = Some("lbl")
).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("one", "lbl"))
}

test("leave while loop") {
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
TestLeaveStatement("lbl"))
),
label = Some("lbl")
)
)
).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("con1", "body1", "lbl"))
}

test("iterate while loop") {
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
TestIterateStatement("lbl"),
TestLeafStatement("body2"))
),
label = Some("lbl")
)
)
).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("con1", "body1", "lbl", "con1", "body1", "lbl", "con1"))
}

test("leave outer loop from nested while loop") {
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"),
body = new CompoundBodyExec(Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
TestLeaveStatement("lbl"))
),
label = Some("lbl2")
)
)),
label = Some("lbl")
)
)
).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("con1", "con2", "body1", "lbl"))
}

test("iterate outer loop from nested while loop") {
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"),
body = new CompoundBodyExec(Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
TestIterateStatement("lbl"),
TestLeafStatement("body2"))
),
label = Some("lbl2")
)
)),
label = Some("lbl")
)
)
).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq(
"con1", "con2", "body1", "lbl",
"con1", "con2", "body1", "lbl",
"con1"))
}
}
Loading

0 comments on commit b5b6b25

Please sign in to comment.