Skip to content

Commit

Permalink
Merge pull request #5436 from sundy-li/agg-null
Browse files Browse the repository at this point in the history
fix(functions): make aggregate function sum/avg/min/max support null …
  • Loading branch information
BohuTANG authored May 18, 2022
2 parents 2e78aa1 + 4f5ca87 commit 5e7a56b
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 89 deletions.
9 changes: 8 additions & 1 deletion common/functions/src/aggregates/aggregate_avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,14 @@ pub fn try_create_aggregate_avg_function(
if data_type.data_type_id() == TypeID::Boolean {
return AggregateAvgFunction::<u8, u64>::try_create(display_name, arguments);
}
with_match_primitive_type_id!(data_type.data_type_id(), |$T| {

let mut phid = data_type.data_type_id();
// A Null-typed argument falls back to a placeholder UInt8 function; null input is already covered by `AggregateNullResultFunction`.
if data_type.is_null() {
phid = TypeID::UInt8;
}

with_match_primitive_type_id!(phid, |$T| {
AggregateAvgFunction::<$T, <$T as PrimitiveType>::LargestType>::try_create(
display_name,
arguments,
Expand Down
8 changes: 7 additions & 1 deletion common/functions/src/aggregates/aggregate_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ pub fn try_create_aggregate_minmax_function<const IS_MIN: bool>(
) -> Result<Arc<dyn AggregateFunction>> {
assert_unary_arguments(display_name, arguments.len())?;
let data_type = arguments[0].data_type().clone();
let phid = data_type.data_type_id().to_physical_type();
let mut phid = data_type.data_type_id().to_physical_type();

// A Null-typed argument falls back to a placeholder UInt8 function; null input is already covered by `AggregateNullResultFunction`.
if data_type.is_null() {
phid = PhysicalTypeID::UInt8;
}

let result = with_match_scalar_types_error!(phid, |$T| {
if IS_MIN {
type State = ScalarState<$T, CmpMin>;
Expand Down
9 changes: 8 additions & 1 deletion common/functions/src/aggregates/aggregate_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,14 @@ pub fn try_create_aggregate_sum_function(
if data_type.data_type_id() == TypeID::Boolean {
return AggregateSumFunction::<u8, u64>::try_create(display_name, arguments);
}
with_match_primitive_type_id!(data_type.data_type_id(), |$T| {

let mut phid = data_type.data_type_id();
// A Null-typed argument falls back to a placeholder UInt8 function; null input is already covered by `AggregateNullResultFunction`.
if data_type.is_null() {
phid = TypeID::UInt8;
}

with_match_primitive_type_id!(phid, |$T| {
AggregateSumFunction::<$T, <$T as PrimitiveType>::LargestType>::try_create(
display_name,
arguments,
Expand Down
4 changes: 3 additions & 1 deletion tests/logictest/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
'database': 'default',
}


def config_from_env():
mysql_host = os.getenv("QUERY_MYSQL_HANDLER_HOST")
if mysql_host is not None:
Expand All @@ -37,5 +38,6 @@ def config_from_env():
if mysql_user is not None:
mysql_config['user'] = mysql_user
http_config['user'] = mysql_user



config_from_env()
75 changes: 43 additions & 32 deletions tests/logictest/gen_suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_database(line):
# but only parse .sql file, .result file will be ignore, run sql and fetch results
# need a local running databend-meta and databend-query or change config.py to your cluster
def get_all_cases():
# copy from databend-test
# copy from databend-test
def collect_subdirs_with_pattern(cur_dir_path, pattern):
return list(
# Make sure all sub-dir name starts with [0-9]+_*.
Expand Down Expand Up @@ -112,28 +112,32 @@ def mysql_fetch_results(sql):
try:
mysql_client.execute(sql)
r = mysql_client.fetchall()

for row in r:
rowlist = []
for item in row:
rowlist.append(str(item))
row_string = " ".join(rowlist)
if len(row_string) == 0: # empty line replace with tab
if len(row_string) == 0: # empty line replace with tab
row_string = "\t"
ret = ret + row_string + "\n"
except Exception as err:
log.warning("SQL: {} fetch no results, msg:{} ,check it manual.".format(sql,str(err)))
log.warning(
"SQL: {} fetch no results, msg:{} ,check it manual.".format(
sql, str(err)))
return ret

target_dir = os.path.dirname(str.replace(sql_file,suite_path,logictest_path))
target_dir = os.path.dirname(
str.replace(sql_file, suite_path, logictest_path))
case_name = os.path.splitext(os.path.basename(sql_file))[0]
target_file = os.path.join(target_dir,case_name)
target_file = os.path.join(target_dir, case_name)

if skip_exist and os.path.exists(target_file):
log.warning("skip case file {}, already exist.".format(target_file))
return

log.info("Write test case to path: {}, case name is {}".format(target_dir, case_name))
log.info("Write test case to path: {}, case name is {}".format(
target_dir, case_name))

content_output = ""
f = open(sql_file, encoding='UTF-8')
Expand All @@ -142,12 +146,12 @@ def mysql_fetch_results(sql):
if is_empty_line(line):
continue

if line.startswith("--"): # pass comment
continue
if line.startswith("--"): # pass comment
continue

# multi line sql
sql_content = sql_content + line.rstrip()
if ';' not in line:
if ';' not in line:
continue

statement = sql_content.strip()
Expand All @@ -157,14 +161,13 @@ def mysql_fetch_results(sql):
errorStatment = get_error_statment(statement)
if errorStatment != None:
content_output = content_output + STATEMENT_ERROR.format(
error_id = errorStatment.group("expectError"),
statement = errorStatment.group("statement")
)
error_id=errorStatment.group("expectError"),
statement=errorStatment.group("statement"))
continue

if str.lower(first_word(statement)) in query_statment_first_words:
if str.lower(first_word(statement)) in query_statment_first_words:
# query statement

try:
http_results = format_result(http_client.fetch_all(statement))
query_options = http_client.get_query_option()
Expand All @@ -174,35 +177,37 @@ def mysql_fetch_results(sql):
continue

if query_options == "":
log.warning("statement: {} type query could not get query_option change to ok statement".format(statement))
content_output = content_output + STATEMENT_OK.format(statement = statement)
log.warning(
"statement: {} type query could not get query_option change to ok statement"
.format(statement))
content_output = content_output + STATEMENT_OK.format(
statement=statement)
continue

mysql_results = mysql_fetch_results(statement)
labels = ""

log.debug("sql: " + statement)
log.debug("mysql return: " + mysql_results)
log.debug("http return: "+ http_results)
log.debug("http return: " + http_results)

if http_results is not None and mysql_results != http_results:
case_results = RESULTS_TEMPLATE.format(
results_string = mysql_results, label = "mysql")
results_string=mysql_results, label="mysql")

case_results = case_results + "\n" + RESULTS_TEMPLATE.format(
results_string = http_results, label = "http")
results_string=http_results, label="http")

labels = "label(mysql,http)"
else:
case_results = RESULTS_TEMPLATE.format(
results_string = mysql_results, label = "")
results_string=mysql_results, label="")

content_output = content_output + STATEMENT_QUERY.format(
query_options = query_options,
statement = statement,
results = case_results,
labels = labels
)
query_options=query_options,
statement=statement,
results=case_results,
labels=labels)
else:
# ok statement
try:
Expand All @@ -218,10 +223,12 @@ def mysql_fetch_results(sql):
http_client.query_with_session(statement)
mysql_client.execute(statement)
except Exception as err:
log.warning("statement {} excute error,msg {}".format(statement, str(err)))
log.warning("statement {} excute error,msg {}".format(
statement, str(err)))
pass

content_output = content_output + STATEMENT_OK.format(statement = statement)
content_output = content_output + STATEMENT_OK.format(
statement=statement)

f.close()
if not os.path.exists(target_dir):
Expand All @@ -231,6 +238,7 @@ def mysql_fetch_results(sql):
caseFile.write(content_output)
caseFile.close()


def output():
print("=================================")
print("Exception sql using Http handler:")
Expand All @@ -242,14 +250,15 @@ def output():
print("\n".join(manual_cases))
print("=================================")


def main():
all_cases = get_all_cases()

for file in all_cases:
# .result and .result_filter files are ignored
if '.result' in file or '.result_filter' in file:
continue
continue

# .py and .sh files are skipped, but recorded so they can be handled manually
if ".py" in file or ".sh" in file:
manual_cases.append(file)
Expand All @@ -260,8 +269,10 @@ def main():
time.sleep(0.01)

output()


if __name__ == '__main__':
log.info("Start generate sqllogictest suites from path: {} to path: {}".format(suite_path, logictest_path))
log.info(
"Start generate sqllogictest suites from path: {} to path: {}".format(
suite_path, logictest_path))
main()
Loading

0 comments on commit 5e7a56b

Please sign in to comment.