Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(functions): make aggregate function sum/avg/min/max support null … #5436

Merged
merged 2 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion common/functions/src/aggregates/aggregate_avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,14 @@ pub fn try_create_aggregate_avg_function(
if data_type.data_type_id() == TypeID::Boolean {
return AggregateAvgFunction::<u8, u64>::try_create(display_name, arguments);
}
with_match_primitive_type_id!(data_type.data_type_id(), |$T| {

let mut phid = data_type.data_type_id();
// null use dummy func, it's already covered in `AggregateNullResultFunction`
if data_type.is_null() {
phid = TypeID::UInt8;
}

with_match_primitive_type_id!(phid, |$T| {
AggregateAvgFunction::<$T, <$T as PrimitiveType>::LargestType>::try_create(
display_name,
arguments,
Expand Down
8 changes: 7 additions & 1 deletion common/functions/src/aggregates/aggregate_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ pub fn try_create_aggregate_minmax_function<const IS_MIN: bool>(
) -> Result<Arc<dyn AggregateFunction>> {
assert_unary_arguments(display_name, arguments.len())?;
let data_type = arguments[0].data_type().clone();
let phid = data_type.data_type_id().to_physical_type();
let mut phid = data_type.data_type_id().to_physical_type();

// null use dummy func, it's already covered in `AggregateNullResultFunction`
if data_type.is_null() {
phid = PhysicalTypeID::UInt8;
}

let result = with_match_scalar_types_error!(phid, |$T| {
if IS_MIN {
type State = ScalarState<$T, CmpMin>;
Expand Down
9 changes: 8 additions & 1 deletion common/functions/src/aggregates/aggregate_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,14 @@ pub fn try_create_aggregate_sum_function(
if data_type.data_type_id() == TypeID::Boolean {
return AggregateSumFunction::<u8, u64>::try_create(display_name, arguments);
}
with_match_primitive_type_id!(data_type.data_type_id(), |$T| {

let mut phid = data_type.data_type_id();
// null use dummy func, it's already covered in `AggregateNullResultFunction`
if data_type.is_null() {
phid = TypeID::UInt8;
}

with_match_primitive_type_id!(phid, |$T| {
AggregateSumFunction::<$T, <$T as PrimitiveType>::LargestType>::try_create(
display_name,
arguments,
Expand Down
4 changes: 3 additions & 1 deletion tests/logictest/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
'database': 'default',
}


def config_from_env():
mysql_host = os.getenv("QUERY_MYSQL_HANDLER_HOST")
if mysql_host is not None:
Expand All @@ -37,5 +38,6 @@ def config_from_env():
if mysql_user is not None:
mysql_config['user'] = mysql_user
http_config['user'] = mysql_user



config_from_env()
75 changes: 43 additions & 32 deletions tests/logictest/gen_suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_database(line):
# but only parse .sql file, .result file will be ignore, run sql and fetch results
# need a local running databend-meta and databend-query or change config.py to your cluster
def get_all_cases():
# copy from databend-test
# copy from databend-test
def collect_subdirs_with_pattern(cur_dir_path, pattern):
return list(
# Make sure all sub-dir name starts with [0-9]+_*.
Expand Down Expand Up @@ -112,28 +112,32 @@ def mysql_fetch_results(sql):
try:
mysql_client.execute(sql)
r = mysql_client.fetchall()

for row in r:
rowlist = []
for item in row:
rowlist.append(str(item))
row_string = " ".join(rowlist)
if len(row_string) == 0: # empty line replace with tab
if len(row_string) == 0: # empty line replace with tab
row_string = "\t"
ret = ret + row_string + "\n"
except Exception as err:
log.warning("SQL: {} fetch no results, msg:{} ,check it manual.".format(sql,str(err)))
log.warning(
"SQL: {} fetch no results, msg:{} ,check it manual.".format(
sql, str(err)))
return ret

target_dir = os.path.dirname(str.replace(sql_file,suite_path,logictest_path))
target_dir = os.path.dirname(
str.replace(sql_file, suite_path, logictest_path))
case_name = os.path.splitext(os.path.basename(sql_file))[0]
target_file = os.path.join(target_dir,case_name)
target_file = os.path.join(target_dir, case_name)

if skip_exist and os.path.exists(target_file):
log.warning("skip case file {}, already exist.".format(target_file))
return

log.info("Write test case to path: {}, case name is {}".format(target_dir, case_name))
log.info("Write test case to path: {}, case name is {}".format(
target_dir, case_name))

content_output = ""
f = open(sql_file, encoding='UTF-8')
Expand All @@ -142,12 +146,12 @@ def mysql_fetch_results(sql):
if is_empty_line(line):
continue

if line.startswith("--"): # pass comment
continue
if line.startswith("--"): # pass comment
continue

# multi line sql
sql_content = sql_content + line.rstrip()
if ';' not in line:
if ';' not in line:
continue

statement = sql_content.strip()
Expand All @@ -157,14 +161,13 @@ def mysql_fetch_results(sql):
errorStatment = get_error_statment(statement)
if errorStatment != None:
content_output = content_output + STATEMENT_ERROR.format(
error_id = errorStatment.group("expectError"),
statement = errorStatment.group("statement")
)
error_id=errorStatment.group("expectError"),
statement=errorStatment.group("statement"))
continue

if str.lower(first_word(statement)) in query_statment_first_words:
if str.lower(first_word(statement)) in query_statment_first_words:
# query statement

try:
http_results = format_result(http_client.fetch_all(statement))
query_options = http_client.get_query_option()
Expand All @@ -174,35 +177,37 @@ def mysql_fetch_results(sql):
continue

if query_options == "":
log.warning("statement: {} type query could not get query_option change to ok statement".format(statement))
content_output = content_output + STATEMENT_OK.format(statement = statement)
log.warning(
"statement: {} type query could not get query_option change to ok statement"
.format(statement))
content_output = content_output + STATEMENT_OK.format(
statement=statement)
continue

mysql_results = mysql_fetch_results(statement)
labels = ""

log.debug("sql: " + statement)
log.debug("mysql return: " + mysql_results)
log.debug("http return: "+ http_results)
log.debug("http return: " + http_results)

if http_results is not None and mysql_results != http_results:
case_results = RESULTS_TEMPLATE.format(
results_string = mysql_results, label = "mysql")
results_string=mysql_results, label="mysql")

case_results = case_results + "\n" + RESULTS_TEMPLATE.format(
results_string = http_results, label = "http")
results_string=http_results, label="http")

labels = "label(mysql,http)"
else:
case_results = RESULTS_TEMPLATE.format(
results_string = mysql_results, label = "")
results_string=mysql_results, label="")

content_output = content_output + STATEMENT_QUERY.format(
query_options = query_options,
statement = statement,
results = case_results,
labels = labels
)
query_options=query_options,
statement=statement,
results=case_results,
labels=labels)
else:
# ok statement
try:
Expand All @@ -218,10 +223,12 @@ def mysql_fetch_results(sql):
http_client.query_with_session(statement)
mysql_client.execute(statement)
except Exception as err:
log.warning("statement {} excute error,msg {}".format(statement, str(err)))
log.warning("statement {} excute error,msg {}".format(
statement, str(err)))
pass

content_output = content_output + STATEMENT_OK.format(statement = statement)
content_output = content_output + STATEMENT_OK.format(
statement=statement)

f.close()
if not os.path.exists(target_dir):
Expand All @@ -231,6 +238,7 @@ def mysql_fetch_results(sql):
caseFile.write(content_output)
caseFile.close()


def output():
print("=================================")
print("Exception sql using Http handler:")
Expand All @@ -242,14 +250,15 @@ def output():
print("\n".join(manual_cases))
print("=================================")


def main():
all_cases = get_all_cases()

for file in all_cases:
# .result will be ignore
if '.result' in file or '.result_filter' in file:
continue
continue

# .py .sh will be ignore, need log
if ".py" in file or ".sh" in file:
manual_cases.append(file)
Expand All @@ -260,8 +269,10 @@ def main():
time.sleep(0.01)

output()


if __name__ == '__main__':
log.info("Start generate sqllogictest suites from path: {} to path: {}".format(suite_path, logictest_path))
log.info(
"Start generate sqllogictest suites from path: {} to path: {}".format(
suite_path, logictest_path))
main()
Loading