Skip to content

Commit

Permalink
run substrait tests
Browse files Browse the repository at this point in the history
* Also, modify snowflake runner to run with private key
  • Loading branch information
srikrishnak committed Nov 27, 2024
1 parent d2efd80 commit 6b88940
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 40 deletions.
12 changes: 9 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@ jobs:
uses: actions/checkout@v3
with:
submodules: recursive

- name: Check for specific files
run: |
echo 'Found files:'
find . -type f -name 'convert_testcase_helper.py' -print
find . -type f -name 'case_file_parser.py' -print
find . -type f -name 'nodes.py' -print
- run: pip install -r requirements.txt
- name: Build & run
run: docker run --rm $(docker build -q --file ./ci/docker/sqlite.Dockerfile .)
duckdb:
Expand All @@ -25,7 +31,7 @@ jobs:
uses: actions/checkout@v3
with:
submodules: recursive

- run: pip install -r requirements.txt
- name: Build & run
run: docker run --rm $(docker build -q --file ./ci/docker/duckdb.Dockerfile .)
datafusion:
Expand All @@ -36,7 +42,7 @@ jobs:
uses: actions/checkout@v3
with:
submodules: recursive

- run: pip install -r requirements.txt
- name: Build & run
run: docker run --rm $(docker build -q --file ./ci/docker/datafusion.Dockerfile .)
postgres:
Expand Down
42 changes: 31 additions & 11 deletions bft/testers/duckdb/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ def is_datetype(arg):
return type(arg) in [datetime.datetime, datetime.date, datetime.timedelta]


def has_only_date(value: datetime.datetime):
if (
value.hour == 0
and value.minute == 0
and value.second == 0
and value.microsecond == 0
):
return True
return False


def datetype_value_equal(result, case_result):
if str(result) == case_result:
return True
if (
isinstance(result, datetime.datetime)
and has_only_date(result)
and str(result.date()) == case_result
):
return True
return False


class DuckDBRunner(SqlCaseRunner):
def __init__(self, dialect):
super().__init__(dialect)
Expand All @@ -62,31 +85,26 @@ def __init__(self, dialect):
def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:

try:
max_args = len(case.args) + 1
if case.function == 'regexp_replace':
max_args = 3
if case.function == 'regexp_match_substring':
max_args = 2
arg_defs = [
f"arg{idx} {type_to_duckdb_type(arg.type)}"
for idx, arg in enumerate(case.args[:max_args])
for idx, arg in enumerate(case.args)
]
schema = ",".join(arg_defs)
self.conn.execute(f"CREATE TABLE my_table({schema});")
self.conn.execute(f"SET TimeZone='UTC';")

arg_names = [f"arg{idx}" for idx in range(len(case.args[:max_args]))]
arg_names = [f"arg{idx}" for idx in range(len(case.args))]
joined_arg_names = ",".join(arg_names)
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
if is_string_type(arg):
arg_vals_list.append("'" + literal_to_str(arg.value) + "'")
else:
arg_vals_list.append(literal_to_str(arg.value))
arg_vals = ", ".join(arg_vals_list)
if mapping.aggregate:
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
arg_vals = ""
for value in arg.value:
if is_string_type(arg):
Expand Down Expand Up @@ -119,7 +137,7 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
if len(arg_names) != 2:
raise Exception(f"Extract function with {len(arg_names)} args")
expr = f"SELECT {mapping.local_name}({arg_vals_list[0]} FROM {arg_names[1]}) FROM my_table;"
elif mapping.local_name == 'count(*)':
elif mapping.local_name == "count(*)":
expr = f"SELECT {mapping.local_name} FROM my_table;"
elif mapping.aggregate:
if len(arg_names) < 1:
Expand Down Expand Up @@ -147,7 +165,9 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
else:
if result == case.result.value:
return SqlCaseResult.success()
elif is_datetype(result) and str(result) == case.result.value:
elif is_datetype(result) and datetype_value_equal(
result, case.result.value
):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
Expand Down
44 changes: 25 additions & 19 deletions bft/testers/snowflake/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import yaml
from typing import Dict, NamedTuple
from cryptography.hazmat.primitives.serialization import load_der_private_key
from cryptography.hazmat.backends import default_backend

from snowflake.connector import connect
from snowflake.connector.errors import Error
Expand All @@ -28,6 +30,7 @@
def type_to_snowflake_type(type: str):
return type_to_dialect_type(type, type_map)


def literal_to_str(lit: str | int | float):
if lit is None:
return "null"
Expand All @@ -41,9 +44,9 @@ def literal_to_str(lit: str | int | float):


def literal_to_float(lit: str | int | float):
if lit in [float('inf'), 'inf']:
if lit in [float("inf"), "inf"]:
return "TO_DOUBLE('inf'::float)"
elif lit in [float('-inf'), '-inf']:
elif lit in [float("-inf"), "-inf"]:
return "TO_DOUBLE('-inf'::float)"
return lit

Expand All @@ -52,11 +55,10 @@ def is_float_type(arg):
return arg.type in ["fp32", "fp64"]



def is_string_type(arg):
return (
arg.type in ["string", "timestamp", "timestamp_tz", "date", "time"]
and arg.value is not None
arg.type in ["string", "timestamp", "timestamp_tz", "date", "time"]
and arg.value is not None
)


Expand All @@ -67,20 +69,24 @@ def is_datetype(arg):
class SnowflakeRunner(SqlCaseRunner):
def __init__(self, dialect):
super().__init__(dialect)
with open('bft/testers/snowflake/config.yaml', 'r') as file:
with open("testers/snowflake/config.yaml", "r") as file:
config = yaml.safe_load(file)
sf_config = config['snowflake']
print(
f"Connecting to {sf_config['account']} as {sf_config['username']}")
self.conn = connect(user=sf_config['username'],
password=os.environ['SNOWSQL_PWD'],
account=sf_config['account'],
database=sf_config['database'],
schema=sf_config['schema'],
# host=sf_config['hostname'].get(""),
# role=sf_config['role'].get(""),
warehouse=sf_config['warehouse']
)
sf_config = config["snowflake"]
print(f"Connecting to {sf_config['account']} as {sf_config['username']}")
private_key_path = os.environ["SNOWSQL_PRIVATE_KEY_PATH"]
with open(private_key_path, "rb") as f:
private_key = f.read()

self.conn = connect(
user=sf_config["username"],
private_key=private_key,
account=sf_config["account"],
database=sf_config["database"],
schema=sf_config["schema"],
# host=sf_config['hostname'].get(""),
# role=sf_config['role'].get(""),
warehouse=sf_config["warehouse"],
)

def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:

Expand Down Expand Up @@ -147,7 +153,7 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
if len(arg_names) != 2:
raise Exception(f"Extract function with {len(arg_names)} args")
expr = f"SELECT {mapping.local_name}({arg_vals_list[0]} FROM {arg_names[1]}) FROM my_table;"
elif mapping.local_name == 'count(*)':
elif mapping.local_name == "count(*)":
expr = f"SELECT {mapping.local_name} FROM my_table;"
elif mapping.aggregate:
if len(arg_names) < 1:
Expand Down
28 changes: 26 additions & 2 deletions bft/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,48 @@

from bft.cases.parser import CaseFileParser
from bft.cases.types import Case
from bft.dialects.types import DialectsLibrary
from bft.testers.base_tester import BaseTester
from tools.convert_testcases.convert_testcases_to_yaml_format import (
convert_directory as convert_directory_from_substrait,
)


# Would be nice to have this as a session-scoped fixture but it doesn't seem that
# parameter values can be a fixture
def cases() -> List[Case]:
cases = []
bft_dir = Path(__file__).parent.parent.parent
parser = CaseFileParser()
cases_dir = Path(__file__) / ".." / ".." / ".." / "cases"
cases_dir = bft_dir / "cases"
substrait_cases_dir = bft_dir / "substrait" / "tests" / "cases"
convert_directory_from_substrait(substrait_cases_dir, cases_dir)
for case_path in cases_dir.resolve().rglob("*.yaml"):
with open(case_path, "rb") as case_f:
for case_file in parser.parse(case_f):
for case in case_file.cases:
case = transform_case(case)
cases.append(case)
return cases


def transform_case(case):
args_len = len(case.args) + 1
if case.function == "regexp_replace":
args_len = 3
elif case.function == "regexp_match_substring":
args_len = 2

# Create a new Case instance with updated `args`
return Case(
function=case.function,
base_uri=case.base_uri,
group=case.group,
args=case.args[:args_len], # Update args here
result=case.result,
options=case.options,
)


def case_id_fn(case: Case):
return f"{case.function}_{case.group.id}_{case.group.index}"

Expand Down
12 changes: 10 additions & 2 deletions ci/docker/sqlite.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@ FROM alpine:3.18
ENV PYTHONUNBUFFERED=1
RUN apk add --update --no-cache python3 && ln -sf python3 /usr/bin/python
RUN python3 -m ensurepip
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe ruamel.yaml antlr4-python3-runtime pytz

WORKDIR /bft
COPY . .

CMD /usr/bin/python -mpytest bft/tests/test_sqlite.py
# Ensure Python finds the substrait directory
ENV PYTHONPATH=/bft/substrait

# CMD to run all commands and display the results
CMD sh -c "echo 'Found files:' && find /bft -type f -name 'convert_testcase_helper.py' -print && \
find . -type f -name 'case_file_parser.py' -print && \
find /bft -type f -name 'nodes.py' -print && \
echo \$PYTHONPATH && \
/usr/bin/python -mpytest bft/tests/test_sqlite.py"
2 changes: 1 addition & 1 deletion substrait
Submodule substrait updated 204 files
4 changes: 2 additions & 2 deletions tools/convert_testcases/convert_testcases_to_yaml_format.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os

from ruamel.yaml import YAML
from tests.coverage.nodes import (
from substrait.tests.coverage.nodes import (
TestFile,
AggregateArgument,
)
from tests.coverage.case_file_parser import load_all_testcases
from substrait.tests.coverage.case_file_parser import load_all_testcases
from tools.convert_testcases.convert_testcase_helper import (
convert_to_yaml_value,
convert_to_long_type,
Expand Down

0 comments on commit 6b88940

Please sign in to comment.