Skip to content

Commit

Permalink
run substrait tests
Browse files Browse the repository at this point in the history
* Also, modify snowflake runner to run with private key
  • Loading branch information
srikrishnak committed Nov 27, 2024
1 parent d2efd80 commit e3cb637
Show file tree
Hide file tree
Showing 17 changed files with 96 additions and 52 deletions.
22 changes: 9 additions & 13 deletions bft/testers/duckdb/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import type_to_dialect_type
from bft.utils.utils import type_to_dialect_type, datetype_value_equal

type_map = {
"i8": "TINYINT",
Expand Down Expand Up @@ -53,7 +53,6 @@ def is_string_type(arg):
def is_datetype(arg):
    """Tell whether *arg* is exactly a datetime, date or timedelta object.

    The check compares ``type(arg)`` itself, so instances of subclasses of
    these three classes are not matched.
    """
    temporal_types = (datetime.datetime, datetime.date, datetime.timedelta)
    return type(arg) in temporal_types


class DuckDBRunner(SqlCaseRunner):
def __init__(self, dialect):
super().__init__(dialect)
Expand All @@ -62,31 +61,26 @@ def __init__(self, dialect):
def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:

try:
max_args = len(case.args) + 1
if case.function == 'regexp_replace':
max_args = 3
if case.function == 'regexp_match_substring':
max_args = 2
arg_defs = [
f"arg{idx} {type_to_duckdb_type(arg.type)}"
for idx, arg in enumerate(case.args[:max_args])
for idx, arg in enumerate(case.args)
]
schema = ",".join(arg_defs)
self.conn.execute(f"CREATE TABLE my_table({schema});")
self.conn.execute(f"SET TimeZone='UTC';")

arg_names = [f"arg{idx}" for idx in range(len(case.args[:max_args]))]
arg_names = [f"arg{idx}" for idx in range(len(case.args))]
joined_arg_names = ",".join(arg_names)
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
if is_string_type(arg):
arg_vals_list.append("'" + literal_to_str(arg.value) + "'")
else:
arg_vals_list.append(literal_to_str(arg.value))
arg_vals = ", ".join(arg_vals_list)
if mapping.aggregate:
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
arg_vals = ""
for value in arg.value:
if is_string_type(arg):
Expand Down Expand Up @@ -119,7 +113,7 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
if len(arg_names) != 2:
raise Exception(f"Extract function with {len(arg_names)} args")
expr = f"SELECT {mapping.local_name}({arg_vals_list[0]} FROM {arg_names[1]}) FROM my_table;"
elif mapping.local_name == 'count(*)':
elif mapping.local_name == "count(*)":
expr = f"SELECT {mapping.local_name} FROM my_table;"
elif mapping.aggregate:
if len(arg_names) < 1:
Expand Down Expand Up @@ -147,7 +141,9 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
else:
if result == case.result.value:
return SqlCaseResult.success()
elif is_datetype(result) and str(result) == case.result.value:
elif is_datetype(result) and datetype_value_equal(
result, case.result.value
):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
Expand Down
5 changes: 4 additions & 1 deletion bft/testers/postgres/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import datetype_value_equal

type_map = {
"i16": "smallint",
Expand Down Expand Up @@ -152,7 +153,9 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
else:
if result == case.result.value:
return SqlCaseResult.success()
elif is_datetype(result) and str(result) == case.result.value:
elif is_datetype(result) and datetype_value_equal(
result, case.result.value
):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
Expand Down
44 changes: 25 additions & 19 deletions bft/testers/snowflake/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import yaml
from typing import Dict, NamedTuple
from cryptography.hazmat.primitives.serialization import load_der_private_key
from cryptography.hazmat.backends import default_backend

from snowflake.connector import connect
from snowflake.connector.errors import Error
Expand All @@ -28,6 +30,7 @@
def type_to_snowflake_type(type: str):
    """Translate a type name through this module's ``type_map``.

    Thin wrapper that delegates to ``type_to_dialect_type`` and returns
    whatever that helper returns.
    """
    return type_to_dialect_type(type=type, type_map=type_map)


def literal_to_str(lit: str | int | float):
if lit is None:
return "null"
Expand All @@ -41,9 +44,9 @@ def literal_to_str(lit: str | int | float):


def literal_to_float(lit: str | int | float):
    """Replace infinite literals with an explicit ``TO_DOUBLE`` expression.

    Positive infinity (the float value or the string ``"inf"``) and negative
    infinity (the float value or the string ``"-inf"``) are each mapped to a
    fixed SQL snippet. Any other literal is returned unchanged.
    """
    positive_spellings = (float("inf"), "inf")
    negative_spellings = (float("-inf"), "-inf")
    if lit in positive_spellings:
        return "TO_DOUBLE('inf'::float)"
    if lit in negative_spellings:
        return "TO_DOUBLE('-inf'::float)"
    return lit

Expand All @@ -52,11 +55,10 @@ def is_float_type(arg):
return arg.type in ["fp32", "fp64"]



def is_string_type(arg):
    """Tell whether *arg* has a string-like type and a non-null value.

    The string-like types are ``string``, ``timestamp``, ``timestamp_tz``,
    ``date`` and ``time``. An argument whose value is ``None`` never counts,
    whatever its type.
    """
    string_like_types = ["string", "timestamp", "timestamp_tz", "date", "time"]
    if arg.type not in string_like_types:
        return False
    return arg.value is not None


Expand All @@ -67,20 +69,24 @@ def is_datetype(arg):
class SnowflakeRunner(SqlCaseRunner):
def __init__(self, dialect):
super().__init__(dialect)
with open('bft/testers/snowflake/config.yaml', 'r') as file:
with open("testers/snowflake/config.yaml", "r") as file:
config = yaml.safe_load(file)
sf_config = config['snowflake']
print(
f"Connecting to {sf_config['account']} as {sf_config['username']}")
self.conn = connect(user=sf_config['username'],
password=os.environ['SNOWSQL_PWD'],
account=sf_config['account'],
database=sf_config['database'],
schema=sf_config['schema'],
# host=sf_config['hostname'].get(""),
# role=sf_config['role'].get(""),
warehouse=sf_config['warehouse']
)
sf_config = config["snowflake"]
print(f"Connecting to {sf_config['account']} as {sf_config['username']}")
private_key_path = os.environ["SNOWSQL_PRIVATE_KEY_PATH"]
with open(private_key_path, "rb") as f:
private_key = f.read()

self.conn = connect(
user=sf_config["username"],
private_key=private_key,
account=sf_config["account"],
database=sf_config["database"],
schema=sf_config["schema"],
# host=sf_config['hostname'].get(""),
# role=sf_config['role'].get(""),
warehouse=sf_config["warehouse"],
)

def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:

Expand Down Expand Up @@ -147,7 +153,7 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
if len(arg_names) != 2:
raise Exception(f"Extract function with {len(arg_names)} args")
expr = f"SELECT {mapping.local_name}({arg_vals_list[0]} FROM {arg_names[1]}) FROM my_table;"
elif mapping.local_name == 'count(*)':
elif mapping.local_name == "count(*)":
expr = f"SELECT {mapping.local_name} FROM my_table;"
elif mapping.aggregate:
if len(arg_names) < 1:
Expand Down
22 changes: 20 additions & 2 deletions bft/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,42 @@

from bft.cases.parser import CaseFileParser
from bft.cases.types import Case
from bft.dialects.types import DialectsLibrary
from bft.testers.base_tester import BaseTester
from tools.convert_testcases.convert_testcases_to_yaml_format import (
convert_directory as convert_directory_from_substrait,
)


# Would be nice to have this as a session-scoped fixture but it doesn't seem that
# parameter values can be a fixture
def cases() -> List[Case]:
    """Load every case from the ``*.yaml`` files below the ``cases`` directory.

    The directory three levels above this file is taken as the base. Before
    anything is read, ``convert_directory_from_substrait`` is called with
    ``<base>/substrait/tests/cases`` and ``<base>/cases``. Each YAML file found
    recursively under ``<base>/cases`` is then parsed, and every case it
    contains is passed through ``transform_case`` before being collected.
    """
    root_dir = Path(__file__).parent.parent.parent
    case_parser = CaseFileParser()
    yaml_cases_dir = root_dir / "cases"
    convert_directory_from_substrait(
        root_dir / "substrait" / "tests" / "cases", yaml_cases_dir
    )
    collected = []
    for yaml_path in yaml_cases_dir.resolve().rglob("*.yaml"):
        with open(yaml_path, "rb") as stream:
            for parsed_file in case_parser.parse(stream):
                collected.extend(
                    transform_case(parsed_case) for parsed_case in parsed_file.cases
                )
    return collected


def transform_case(case):
    """Build a new ``Case`` carrying the same field values as *case*.

    Every field (``function``, ``base_uri``, ``group``, ``args``, ``result``
    and ``options``) is copied through unchanged; nothing is rewritten here.
    """
    # NOTE(review): presumably a hook for adjusting `args` later -- confirm.
    copied_fields = ("function", "base_uri", "group", "args", "result", "options")
    return Case(**{field: getattr(case, field) for field in copied_fields})


def case_id_fn(case: Case):
    """Build an identifier of the form ``<function>_<group id>_<group index>``."""
    return "{}_{}_{}".format(case.function, case.group.id, case.group.index)

Expand Down
22 changes: 22 additions & 0 deletions bft/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
import datetime


def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
Expand All @@ -24,3 +25,24 @@ def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
return type_val
# transform parameterized type name to have dialect type
return type.replace(type_to_check, type_val).replace("<", "(").replace(">", ")")

def has_only_date(value: datetime.datetime):
    """Return True when the time-of-day part of *value* is exactly midnight.

    That is, hour, minute, second and microsecond are all zero.
    """
    # The `or` chain stops at the first non-zero component, in the same
    # hour -> minute -> second -> microsecond order.
    return not (value.hour or value.minute or value.second or value.microsecond)

def datetype_value_equal(result, case_result):
    """Compare a result value with the expected value from a test case.

    The two match when ``str(result)`` equals *case_result*. A
    ``datetime.datetime`` whose time part is midnight also matches when the
    string form of its date alone equals *case_result*.
    """
    if str(result) == case_result:
        return True
    if not isinstance(result, datetime.datetime):
        return False
    # Midnight datetimes may be written in the expected value as a bare date.
    return has_only_date(result) and str(result.date()) == case_result
3 changes: 2 additions & 1 deletion ci/docker/base-tester.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ FROM alpine:3.18
ARG PIP_PACKAGES

ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/bft/substrait
RUN apk add --update --no-cache python3 && ln -sf python3 /usr/bin/python
RUN python3 -m ensurepip
RUN echo "PIP_PACKAGES is $PIP_PACKAGES"
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe $PIP_PACKAGES
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe $PIP_PACKAGES ruamel.yaml antlr4-python3-runtime pytz

WORKDIR /bft
COPY . .
Expand Down
3 changes: 2 additions & 1 deletion ci/docker/datafusion.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
FROM ubuntu:22.04

ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/bft/substrait
RUN apt-get update && apt-get install -y python3.10 && ln -sf python3 /usr/bin/python
RUN apt install -y pip
RUN pip install --upgrade pip setuptools pytest pyyaml mistletoe datafusion
RUN pip install --upgrade pip setuptools pytest pyyaml mistletoe datafusion ruamel.yaml antlr4-python3-runtime pytz numpy

WORKDIR /bft
COPY . .
Expand Down
3 changes: 2 additions & 1 deletion ci/docker/duckdb.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
FROM alpine:3.18

ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/bft/substrait
RUN apk add --update --no-cache python3 && ln -sf python3 /usr/bin/python
RUN python3 -m ensurepip
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe duckdb
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe duckdb ruamel.yaml antlr4-python3-runtime pytz

WORKDIR /bft
COPY . .
Expand Down
3 changes: 2 additions & 1 deletion ci/docker/postgres-server.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ ENV POSTGRES_DB=bft
ENV POSTGRES_PASSWORD=postgres

ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/bft/substrait
RUN apk add --update --no-cache python3 && ln -sf python3 /usr/bin/python
RUN python3 -m ensurepip
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe psycopg[binary]
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe psycopg[binary] ruamel.yaml antlr4-python3-runtime pytz

WORKDIR /bft
COPY . .
Expand Down
4 changes: 3 additions & 1 deletion ci/docker/sqlite.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
FROM alpine:3.18

ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/bft/substrait
RUN apk add --update --no-cache python3 && ln -sf python3 /usr/bin/python
RUN python3 -m ensurepip
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe ruamel.yaml antlr4-python3-runtime pytz

WORKDIR /bft
COPY . .

# CMD to run all commands and display the results
CMD /usr/bin/python -mpytest bft/tests/test_sqlite.py
3 changes: 2 additions & 1 deletion ci/docker/velox.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
FROM ubuntu:22.04

ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/bft/substrait
RUN apt-get update && apt-get install -y \
python3 \
python3-pip
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe pyvelox
RUN pip3 install --no-cache --upgrade pip setuptools pytest pyyaml mistletoe pyvelox ruamel.yaml antlr4-python3-runtime pytz

WORKDIR /bft
COPY . .
Expand Down
1 change: 0 additions & 1 deletion dialects/datafusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ scalar_functions:
required_options:
lookaround: false
supported_kernels:
- str_str_str_i64_i64
- vchar_vchar_vchar_i64_i64
- str_str_str
- name: string.bit_length
Expand Down
2 changes: 0 additions & 2 deletions dialects/duckdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,6 @@ scalar_functions:
required_options:
lookaround: false
supported_kernels:
- str_str_str_i64_i64
- vchar_vchar_vchar_i64_i64
- str_str_str
- name: string.regexp_string_split
Expand All @@ -514,7 +513,6 @@ scalar_functions:
lookaround: false
supported_kernels:
- vchar_vchar_i64_i64_i64
- str_str_i64_i64_i64
- str_str
- name: string.bit_length
supported_kernels:
Expand Down
3 changes: 0 additions & 3 deletions dialects/postgres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,6 @@ scalar_functions:
required_options:
lookaround: true
supported_kernels:
- str_str_str_i64_i64
- vchar_vchar_vchar_i64_i64
- str_str_str
- name: string.regexp_string_split
Expand All @@ -541,15 +540,13 @@ scalar_functions:
- name: string.regexp_count_substring
local_name: regexp_count
supported_kernels:
- str_str_i64
- vchar_vchar_i64
- fchar_fchar_i64
- str_str
- name: string.regexp_match_substring
local_name: regexp_substr
supported_kernels:
- vchar_vchar_i64_i64_i64
- str_str_i64_i64_i64
- str_str
- name: string.bit_length
supported_kernels:
Expand Down
2 changes: 0 additions & 2 deletions dialects/velox_presto.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,12 @@ scalar_functions:
required_options:
spaces_only: true
supported_kernels:
- vchar_vchar
- str_str
- str
- name: string.rtrim
required_options:
spaces_only: true
supported_kernels:
- vchar_vchar
- str_str
- str
- name: string.trim
Expand Down
Loading

0 comments on commit e3cb637

Please sign in to comment.