Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support uncorrelated scalar subquery after comparison operator in WHERE clause #14148

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.sql.Statement;
import java.util.Arrays;

import static org.apache.iotdb.db.it.utils.TestUtils.tableAssertTestFail;
import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest;
import static org.junit.Assert.fail;

Expand Down Expand Up @@ -154,14 +153,6 @@ public class IoTDBMultiIDsWithAttributesTableIT {
String[] retArray;
static String sql;

// public static void main(String[] args) {
// for (String[] sqlList : Arrays.asList(sql4, sql5)) {
// for (String sql : sqlList) {
// System.out.println(sql);
// }
// }
// }

@BeforeClass
public static void setUp() throws Exception {
EnvFactory.getEnv().getConfig().getDataNodeCommonConfig().setSortBufferSize(1024 * 1024L);
Expand Down Expand Up @@ -1694,17 +1685,21 @@ public void fourTableJoinTest() {
+ "order by s.student_id, t.teacher_id, c.course_id,g.grade_id";
tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);

expectedHeader = new String[] {"region", "name", "teacher_id", "course_name", "score"};

retArray =
new String[] {
"haidian,Lucy,1005,数学,99,",
};
sql =
"select s.region, s.name,"
+ " t.teacher_id,"
+ " c.course_name,"
+ " g.score "
+ "from students s, teachers t, courses c, grades g "
+ "where s.time=c.time and c.time=g.time";
tableAssertTestFail(
sql,
"701: Cross join is not supported in current version, each table must have at least one equiJoinClause",
DATABASE_NAME);
+ "where s.time=c.time and c.time=g.time and t.teacher_id = 1005 limit 1";

tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3713,7 +3713,7 @@ public void modeTest() {
public void exceptionTest() {
tableAssertTestFail(
"select s1 from table1 where s2 in (select s2 from table1)",
"701: Only TableSubquery is supported now",
"Not a valid IR expression",
DATABASE_NAME);

tableAssertTestFail(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iotdb.relational.it.query.recent.subquery;

import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.it.framework.IoTDBTestRunner;
import org.apache.iotdb.itbase.category.TableClusterIT;
import org.apache.iotdb.itbase.category.TableLocalStandaloneIT;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;

import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData;
import static org.apache.iotdb.db.it.utils.TestUtils.tableAssertTestFail;
import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest;
import static org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.CREATE_SQLS;
import static org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.DATABASE_NAME;
import static org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.NUMERIC_MEASUREMENTS;

@RunWith(IoTDBTestRunner.class)
@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
public class IoTDBUncorrelatedSubqueryInWhereClauseIT {

@BeforeClass
public static void setUp() throws Exception {
EnvFactory.getEnv().getConfig().getCommonConfig().setSortBufferSize(128 * 1024);
EnvFactory.getEnv().getConfig().getCommonConfig().setMaxTsBlockSizeInByte(4 * 1024);
EnvFactory.getEnv().initClusterEnvironment();
prepareTableData(CREATE_SQLS);
}

@AfterClass
public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment();
}

@Test
public void testScalarSubqueryAfterComparisonInOneTable() {
String sql;
String[] expectedHeader;
String[] retArray;

// Test case: s equals to the maximum value of s in table1
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = (SELECT max(%s) from table1 WHERE device_id = 'd01')";
retArray = new String[] {"70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s not equals to the maximum value of s in table1
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s != ((SELECT max(%s) FROM table1 WHERE device_id = 'd01'))";
retArray = new String[] {"30,", "40,", "50,", "60,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s greater than the average value of s in table1
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s >= ((SELECT AVG(%s) FROM table1 WHERE device_id = 'd01'))";
retArray = new String[] {"50,", "60,", "70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s greater than the max value of s in table1
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > ((SELECT max(%s) FROM table1 WHERE device_id = 'd01'))";
retArray = new String[] {};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s is less than the maximum value of s in table1 and greater than the minimum value
// of s in table1
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < (SELECT max(%s) from table1 WHERE device_id = 'd01') and %s > (SELECT min(%s) from table1 WHERE device_id = 'd01') ";
retArray = new String[] {"40,", "50,", "60,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(
sql, measurement, measurement, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s greater than the avg value of s in table1 and s5 = true
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > ((SELECT avg(%s) FROM table1 WHERE device_id = 'd01' and s5 = true))";
retArray = new String[] {"60,", "70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s greater than the count value of s in table1
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > (SELECT count(%s) FROM table1 WHERE device_id = 'd01')";
retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s less than the sum value of s in table1
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < (SELECT sum(%s) FROM table1 WHERE device_id = 'd01')";
retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: subquery is not aggregation but returns exactly one row
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = (SELECT %s FROM table1 WHERE device_id = 'd01' and %s = 30)";
retArray = new String[] {"30,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(
sql, measurement, measurement, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}
}

@Test
public void testScalarSubqueryAfterComparisonInDifferentTables() {
String sql;
String[] expectedHeader;
String[] retArray;

// Test case: s greater than the count value of s in table2
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > (SELECT count(%s) from table2)";
retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: s less than the max value of s in table2 * the count value of s in table2 * 10
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < ((SELECT max(%s) from table2) * (SELECT count(%s) from table2)) * 10";
retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(sql, measurement, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}
}

@Test
public void testNestedScalarSubqueryAfterComparison() {
String sql;
String[] expectedHeader;
String[] retArray;

// Test case: nested scalar subquery in where clause
sql =
"SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = (SELECT max(%s) from table1 where %s = (SELECT max(%s) from table1 WHERE device_id = 'd01'))";
retArray = new String[] {"70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(
sql, measurement, measurement, measurement, measurement, measurement, measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}

// Test case: nested scalar subquery with table subquery
sql =
"SELECT %s from (SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = (SELECT max(%s) from table1 where %s = (SELECT max(%s) from table1 WHERE device_id = 'd01')))";
retArray = new String[] {"70,"};
for (String measurement : NUMERIC_MEASUREMENTS) {
expectedHeader = new String[] {measurement};
tableResultSetEqualTest(
String.format(
sql,
measurement,
measurement,
measurement,
measurement,
measurement,
measurement,
measurement),
expectedHeader,
retArray,
DATABASE_NAME);
}
}

@Test
public void testScalarSubqueryAfterComparisonLegalityCheck() {
// Legality check: subquery returns multiple rows (should fail)
tableAssertTestFail(
"select s1 from table1 where s1 = (select s1 from table1)",
"301: Scalar sub-query has returned multiple rows.",
DATABASE_NAME);

// Legality check: subquery can not be parsed
tableAssertTestFail(
"select s1 from table1 where s1 = (select s1 from)", "mismatched input", DATABASE_NAME);

// Legality check: subquery can not be parsed(without parentheses)
tableAssertTestFail(
"select s1 from table1 where s1 = select s1 from table1",
"mismatched input",
DATABASE_NAME);

// Legality check: Main query can not be parsed
tableAssertTestFail(
"select s1 from table1 where s1 = (select max(s1) from table1) and",
"mismatched input",
DATABASE_NAME);
}
}
Loading
Loading