Skip to content

Commit

Permalink
fix(deps): update dependency sqlglot to >=23.4,<23.10 (#8787)
Browse files (browse the repository at this point in the history)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Co-authored-by: Gil Forsyth <gil@forsyth.dev>
  • Loading branch information
3 people authored Apr 13, 2024
1 parent 9f47670 commit 0f00101
Show file tree
Hide file tree
Showing 19 changed files with 57 additions and 96 deletions.
25 changes: 5 additions & 20 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,7 @@ def visti_StringFind(self, op, *, arg, substr, start, end):
return self.f.strpos(arg, substr)

def visit_NonNullLiteral(self, op, *, value, dtype):
if dtype.is_string():
return sge.convert(
str(value)
# Escape \ first so we don't double escape other characters.
.replace("\\", "\\\\")
# ASCII escape sequences that are recognized in Python:
# https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals
.replace("\a", "\\a") # Bell
.replace("\b", "\\b") # Backspace
.replace("\f", "\\f") # Formfeed
.replace("\n", "\\n") # Newline / Linefeed
.replace("\r", "\\r") # Carriage return
.replace("\t", "\\t") # Tab
.replace("\v", "\\v") # Vertical tab
)
elif dtype.is_inet() or dtype.is_macaddr():
if dtype.is_inet() or dtype.is_macaddr():
return sge.convert(str(value))
elif dtype.is_timestamp():
funcname = "datetime" if dtype.timezone is None else "timestamp"
Expand Down Expand Up @@ -570,10 +555,10 @@ def visit_RegexExtract(self, op, *, arg, pattern, index):
nonzero_index_replace = self.f.regexp_replace(
arg,
self.f.concat(".*?", pattern, ".*"),
self.f.concat("\\\\", self.cast(index, dt.string)),
self.f.concat("\\", self.cast(index, dt.string)),
)
zero_index_replace = self.f.regexp_replace(
arg, self.f.concat(".*?", self.f.concat("(", pattern, ")"), ".*"), "\\\\1"
arg, self.f.concat(".*?", self.f.concat("(", pattern, ")"), ".*"), "\\1"
)
extract = self.if_(index.eq(0), zero_index_replace, nonzero_index_replace)
return self.if_(matches, extract, NULL)
Expand Down Expand Up @@ -653,7 +638,7 @@ def visit_TypeOf(self, op, *, arg):
self.if_(self.f.regexp_contains(name, "^-?[0-9]*$"), "INT64"),
self.if_(
self.f.regexp_contains(
name, r'^(-?[0-9]+[.e].*|CAST\\("([^"]*)" AS FLOAT64\\))$'
name, r'^(-?[0-9]+[.e].*|CAST\("([^"]*)" AS FLOAT64\))$'
),
"FLOAT64",
),
Expand All @@ -664,7 +649,7 @@ def visit_TypeOf(self, op, *, arg):
),
self.if_(self.f.starts_with(name, 'b"'), "BYTES"),
self.if_(self.f.starts_with(name, "["), "ARRAY"),
self.if_(self.f.regexp_contains(name, r"^(STRUCT)?\\("), "STRUCT"),
self.if_(self.f.regexp_contains(name, r"^(STRUCT)?\("), "STRUCT"),
self.if_(self.f.starts_with(name, "ST_"), "GEOGRAPHY"),
self.if_(name.eq(sge.convert("NULL")), "NULL"),
]
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
if dtype.is_inet():
v = str(value)
return self.f.toIPv6(v) if ":" in v else self.f.toIPv4(v)
elif dtype.is_string():
return sge.convert(str(value).replace(r"\0", r"\\0"))
elif dtype.is_decimal():
precision = dtype.precision
if precision is None or not 1 <= precision <= 76:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
SELECT
CASE
WHEN notEmpty(
extractGroups(CAST("t0"."string_col" AS String), CONCAT('(', '[\d]+', ')'))[3 + 1]
extractGroups(CAST("t0"."string_col" AS String), CONCAT('(', '[\\d]+', ')'))[3 + 1]
)
THEN extractGroups(CAST("t0"."string_col" AS String), CONCAT('(', '[\d]+', ')'))[3 + 1]
THEN extractGroups(CAST("t0"."string_col" AS String), CONCAT('(', '[\\d]+', ')'))[3 + 1]
ELSE NULL
END AS "RegexExtract(string_col, '[\\d]+', 3)"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
replaceRegexpAll("t0"."string_col", '[\d]+', 'aaa') AS "RegexReplace(string_col, '[\\d]+', 'aaa')"
replaceRegexpAll("t0"."string_col", '[\\d]+', 'aaa') AS "RegexReplace(string_col, '[\\d]+', 'aaa')"
FROM "functional_alltypes" AS "t0"
59 changes: 26 additions & 33 deletions ibis/backends/flink/tests/test_compiler.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

from operator import methodcaller

import pytest
from pytest import param

import ibis
from ibis.common.deferred import _


def test_sum(con, simple_table, assert_sql):
def test_sum(simple_table, assert_sql):
expr = simple_table.a.sum()
assert_sql(expr)


def test_count_star(con, simple_table, assert_sql):
def test_count_star(simple_table, assert_sql):
expr = simple_table.group_by(simple_table.i).size()
assert_sql(expr)

Expand All @@ -24,12 +26,12 @@ def test_count_star(con, simple_table, assert_sql):
param("s", id="timestamp_s"),
],
)
def test_timestamp_from_unix(con, simple_table, unit, assert_sql):
def test_timestamp_from_unix(simple_table, unit, assert_sql):
expr = simple_table.d.to_timestamp(unit=unit)
assert_sql(expr)


def test_complex_projections(con, simple_table, assert_sql):
def test_complex_projections(simple_table, assert_sql):
expr = (
simple_table.group_by(["a", "c"])
.aggregate(the_sum=simple_table.b.sum())
Expand All @@ -39,7 +41,7 @@ def test_complex_projections(con, simple_table, assert_sql):
assert_sql(expr)


def test_filter(con, simple_table, assert_sql):
def test_filter(simple_table, assert_sql):
expr = simple_table[
((simple_table.c > 0) | (simple_table.c < 0)) & simple_table.g.isin(["A", "B"])
]
Expand All @@ -60,12 +62,12 @@ def test_filter(con, simple_table, assert_sql):
"second",
],
)
def test_extract_fields(con, simple_table, kind, assert_sql):
def test_extract_fields(simple_table, kind, assert_sql):
expr = getattr(simple_table.i, kind)().name("tmp")
assert_sql(expr)


def test_complex_groupby_aggregation(con, simple_table, assert_sql):
def test_complex_groupby_aggregation(simple_table, assert_sql):
keys = [simple_table.i.year().name("year"), simple_table.i.month().name("month")]
b_unique = simple_table.b.nunique()
expr = simple_table.group_by(keys).aggregate(
Expand All @@ -74,12 +76,12 @@ def test_complex_groupby_aggregation(con, simple_table, assert_sql):
assert_sql(expr)


def test_simple_filtered_agg(con, simple_table, assert_sql):
def test_simple_filtered_agg(simple_table, assert_sql):
expr = simple_table.b.nunique(where=simple_table.g == "A")
assert_sql(expr)


def test_complex_filtered_agg(con, snapshot, simple_table, assert_sql):
def test_complex_filtered_agg(simple_table, assert_sql):
expr = simple_table.group_by("b").aggregate(
total=simple_table.count(),
avg_a=simple_table.a.mean(),
Expand All @@ -89,12 +91,12 @@ def test_complex_filtered_agg(con, snapshot, simple_table, assert_sql):
assert_sql(expr)


def test_value_counts(con, simple_table, assert_sql):
def test_value_counts(simple_table, assert_sql):
expr = simple_table.i.year().value_counts()
assert_sql(expr)


def test_having(con, simple_table, assert_sql):
def test_having(simple_table, assert_sql):
expr = (
simple_table.group_by("g")
.having(simple_table.count() >= 1000)
Expand All @@ -104,37 +106,28 @@ def test_having(con, simple_table, assert_sql):


@pytest.mark.parametrize(
"function_type,params",
"method",
[
pytest.param(
"tumble", {"window_size": ibis.interval(minutes=15)}, id="tumble_window"
),
pytest.param(
methodcaller("tumble", window_size=ibis.interval(minutes=15)),
methodcaller(
"hop",
{
"window_size": ibis.interval(minutes=15),
"window_slide": ibis.interval(minutes=1),
},
id="hop_window",
window_size=ibis.interval(minutes=15),
window_slide=ibis.interval(minutes=1),
),
pytest.param(
methodcaller(
"cumulate",
{
"window_size": ibis.interval(minutes=1),
"window_step": ibis.interval(seconds=10),
},
id="cumulate_window",
window_size=ibis.interval(minutes=1),
window_step=ibis.interval(seconds=10),
),
],
ids=["tumble", "hop", "cumulate"],
)
def test_windowing_tvf(con, simple_table, function_type, params, assert_sql):
expr = getattr(simple_table.window_by(time_col=simple_table.i), function_type)(
**params
)
def test_windowing_tvf(simple_table, method, assert_sql):
expr = method(simple_table.window_by(time_col=simple_table.i))
assert_sql(expr)


def test_window_aggregation(con, simple_table, assert_sql):
def test_window_aggregation(simple_table, assert_sql):
expr = (
simple_table.window_by(time_col=simple_table.i)
.tumble(window_size=ibis.interval(minutes=15))
Expand All @@ -144,7 +137,7 @@ def test_window_aggregation(con, simple_table, assert_sql):
assert_sql(expr)


def test_window_topn(con, simple_table, assert_sql):
def test_window_topn(simple_table, assert_sql):
expr = simple_table.window_by(time_col="i").tumble(
window_size=ibis.interval(seconds=600),
)["a", "b", "c", "d", "g", "window_start", "window_end"]
Expand Down
16 changes: 0 additions & 16 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
# supported, but only within a certain range, and the
# implementation wraps on over- and underflow
return sge.convert(value.isoformat())
elif dtype.is_string():
value = (
value
# Escape \ first so we don't double escape other characters.
.replace("\\", "\\\\")
# ASCII escape sequences that are recognized in Python:
# https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals
.replace("\a", "\\a") # Bell
.replace("\b", "\\b") # Backspace
.replace("\f", "\\f") # Formfeed
.replace("\n", "\\n") # Newline / Linefeed
.replace("\r", "\\r") # Carriage return
.replace("\t", "\\t") # Tab
.replace("\v", "\\v") # Vertical tab
)
return sge.convert(value)
elif dtype.is_decimal() and not value.is_finite():
raise com.UnsupportedOperationError(
f"Non-finite decimal literal values are not supported by Impala; got: {value}"
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
raise com.UnsupportedBackendType(
"MySQL does not support arrays, structs or maps"
)
elif dtype.is_string():
return sge.convert(value.replace("\\", "\\\\"))
return None

def visit_JSONGetItem(self, op, *, arg, index):
Expand Down Expand Up @@ -260,7 +258,7 @@ def visit_RegexExtract(self, op, *, arg, pattern, index):
index.eq(0),
extracted,
self.f.regexp_replace(
extracted, pattern, rf"\\{index.sql(self.dialect)}"
extracted, pattern, f"\\{index.sql(self.dialect)}"
),
),
NULL,
Expand Down Expand Up @@ -336,7 +334,7 @@ def visit_RStrip(self, op, *, arg):
return self.visit_LRStrip(op, arg=arg, position="TRAILING")

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=sge.convert(op.resolution.upper()))
return sge.Interval(this=arg, unit=sge.Var(this=op.resolution.upper()))

def visit_TimestampAdd(self, op, *, left, right):
if op.right.dtype.unit.short == "ms":
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def visit_ExtractEpochSeconds(self, op, *, arg):

def visit_ArrayIndex(self, op, *, arg, index):
index = self.if_(index < 0, self.f.cardinality(arg) + index, index)
return sge.paren(arg, copy=False)[index + 1]
return sge.paren(arg, copy=False)[index]

def visit_ArraySlice(self, op, *, arg, start, stop):
neg_to_pos_index = lambda n, index: self.if_(index < 0, n + index, index)
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
return self.f.nanvl(result, NULL)
else:
return result
elif dtype.is_string():
value = value.replace("\\", "\\\\")
return sge.convert(value)
elif dtype.is_binary():
return self.f.unhex(value.hex())
elif dtype.is_decimal():
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ def _minimize_spec(start, end, spec):
def visit_Literal(self, op, *, value, dtype):
if value is None:
return super().visit_Literal(op, value=value, dtype=dtype)
elif dtype.is_string():
# sqlglot doesn't escape backslashes in strings
return sge.convert(value.replace("\\", "\\\\"))
elif dtype.is_timestamp():
args = (
value.year,
Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ def visit_DefaultLiteral(self, op, *, value, dtype):
return self.cast(str(value), dtype)
elif dtype.is_interval():
return sge.Interval(
this=sge.convert(str(value)), unit=dtype.resolution.upper()
this=sge.convert(str(value)),
unit=sge.Var(this=dtype.resolution.upper()),
)
elif dtype.is_boolean():
return sge.Boolean(this=bool(value))
Expand Down Expand Up @@ -788,7 +789,9 @@ def visit_DayOfWeekName(self, op, *, arg):
)

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=sge.convert(arg), unit=unit.singular.upper())
return sge.Interval(
this=sge.convert(arg), unit=sge.Var(this=unit.singular.upper())
)

### String Instruments

Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType
if isinstance(typecode, sge.Interval):
typ = sge.DataType(
this=sge.DataType.Type.INTERVAL,
expressions=[sge.IntervalSpan(this=typecode.unit)],
expressions=[typecode.unit],
)
typecode = typ.this

Expand Down Expand Up @@ -731,6 +731,10 @@ def _from_sqlglot_DATETIME(cls) -> dt.Timestamp:

@classmethod
def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp:
return dt.Timestamp(timezone=None, nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_TIMESTAMPTZ(cls) -> dt.Timestamp:
return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable)

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _interval_with_precision(self, e):
formatted_arg = f"'{formatted_arg}'"
prec = _calculate_precision(int(arg))
prec = max(prec, 2)
unit += f"({prec})"
unit.args["this"] += f"({prec})"

return f"INTERVAL {formatted_arg} {unit}"

Expand Down Expand Up @@ -182,6 +182,8 @@ def new_name(names: set[str], name: str) -> str:


class Flink(Hive):
UNESCAPED_SEQUENCES = {"\\\\d": "\\d"}

class Generator(Hive.Generator):
UNNEST_WITH_ORDINALITY = False

Expand Down
Loading

0 comments on commit 0f00101

Please sign in to comment.