Skip to content

Commit

Permalink
feat(flink): fine-tune numeric literal translation
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman authored and jcrist committed Aug 29, 2023
1 parent c4ec529 commit 2f2d0d9
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT t0.*
FROM table t0
WHERE ((t0.`c` > 0) OR (t0.`c` < 0)) AND
WHERE ((t0.`c` > CAST(0 AS TINYINT)) OR (t0.`c` < CAST(0 AS TINYINT))) AND
(t0.`g` IN ('A', 'B'))
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT t0.`g`, sum(t0.`b`) AS `b_sum`
FROM table t0
GROUP BY t0.`g`
HAVING count(*) >= 1000
HAVING count(*) >= CAST(1000 AS SMALLINT)
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC ROWS BETWEEN 1000 PRECEDING AND CURRENT ROW) AS `Sum(f)`
SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC ROWS BETWEEN CAST(1000 AS SMALLINT) PRECEDING AND CURRENT ROW) AS `Sum(f)`
FROM table t0
4 changes: 2 additions & 2 deletions ibis/backends/flink/tests/test_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
@pytest.mark.parametrize(
"value,expected",
[
param(5, "5", id="int"),
param(1.5, "1.5", id="float"),
param(5, "CAST(5 AS TINYINT)", id="int"),
param(1.5, "CAST(1.5 AS DOUBLE)", id="float"),
param(True, "TRUE", id="true"),
param(False, "FALSE", id="false"),
],
Expand Down
20 changes: 18 additions & 2 deletions ibis/backends/flink/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from abc import ABC, abstractmethod
from collections import defaultdict

from pyflink.table.types import DataTypes

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.temporal import IntervalUnit
Expand All @@ -31,7 +33,6 @@
IntervalUnit.SECOND: 9,
}


MICROSECONDS_IN_UNIT = {
unit: datetime.timedelta(**{unit.plural: 1}).total_seconds() * 10**6
for unit in [
Expand Down Expand Up @@ -244,6 +245,21 @@ def _translate_interval(value, dtype):
return interval.format_as_string()


# Lookup table from ibis numeric dtype classes to PyFlink ``DataTypes``
# instances. ``translate_literal`` indexes it with ``type(dtype)`` and renders
# the result with ``!s`` as the target type of a ``CAST(<value> AS <type>)``
# expression, so every numeric dtype class that reaches that branch needs an
# entry here (a missing class would surface as a ``KeyError`` on lookup).
_to_pyflink_types = {
# Signed integers, in order of increasing bit width.
dt.Int8: DataTypes.TINYINT(),
dt.Int16: DataTypes.SMALLINT(),
dt.Int32: DataTypes.INT(),
dt.Int64: DataTypes.BIGINT(),
# Unsigned integers map to the same Flink type as the signed dtype of equal
# bit width.
# NOTE(review): presumably because Flink SQL has no unsigned integer types;
# unsigned values above the signed maximum may not fit the cast target --
# confirm how such literals should be handled.
dt.UInt8: DataTypes.TINYINT(),
dt.UInt16: DataTypes.SMALLINT(),
dt.UInt32: DataTypes.INT(),
dt.UInt64: DataTypes.BIGINT(),
# Float16 shares FLOAT with Float32; Float64 maps to DOUBLE.
dt.Float16: DataTypes.FLOAT(),
dt.Float32: DataTypes.FLOAT(),
dt.Float64: DataTypes.DOUBLE(),
}


def translate_literal(op: ops.Literal) -> str:
value = op.value
dtype = op.dtype
Expand All @@ -266,7 +282,7 @@ def translate_literal(op: ops.Literal) -> str:
raise ValueError("NaN is not supported in Flink SQL")
elif math.isinf(value):
raise ValueError("Infinity is not supported in Flink SQL")
return repr(value)
return f"CAST({value} AS {_to_pyflink_types[type(dtype)]!s})"
elif dtype.is_timestamp():
# TODO(chloeh13q): support timestamp with local timezone
if isinstance(value, datetime.datetime):
Expand Down
11 changes: 11 additions & 0 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"trino": "integer",
"duckdb": "TINYINT",
"postgres": "integer",
"flink": "TINYINT NOT NULL",
},
id="int8",
),
Expand All @@ -85,6 +86,7 @@
"trino": "integer",
"duckdb": "SMALLINT",
"postgres": "integer",
"flink": "SMALLINT NOT NULL",
},
id="int16",
),
Expand All @@ -99,6 +101,7 @@
"trino": "integer",
"duckdb": "INTEGER",
"postgres": "integer",
"flink": "INT NOT NULL",
},
id="int32",
),
Expand All @@ -113,6 +116,7 @@
"trino": "integer",
"duckdb": "BIGINT",
"postgres": "integer",
"flink": "BIGINT NOT NULL",
},
id="int64",
),
Expand All @@ -127,6 +131,7 @@
"trino": "integer",
"duckdb": "UTINYINT",
"postgres": "integer",
"flink": "TINYINT NOT NULL",
},
id="uint8",
),
Expand All @@ -141,6 +146,7 @@
"trino": "integer",
"duckdb": "USMALLINT",
"postgres": "integer",
"flink": "SMALLINT NOT NULL",
},
id="uint16",
),
Expand All @@ -155,6 +161,7 @@
"trino": "integer",
"duckdb": "UINTEGER",
"postgres": "integer",
"flink": "INT NOT NULL",
},
id="uint32",
),
Expand All @@ -169,6 +176,7 @@
"trino": "integer",
"duckdb": "UBIGINT",
"postgres": "integer",
"flink": "BIGINT NOT NULL",
},
id="uint64",
),
Expand All @@ -183,6 +191,7 @@
"trino": "double",
"duckdb": "FLOAT",
"postgres": "numeric",
"flink": "FLOAT NOT NULL",
},
marks=[
pytest.mark.notimpl(
Expand All @@ -209,6 +218,7 @@
"trino": "double",
"duckdb": "FLOAT",
"postgres": "numeric",
"flink": "FLOAT NOT NULL",
},
id="float32",
),
Expand All @@ -223,6 +233,7 @@
"trino": "double",
"duckdb": "DOUBLE",
"postgres": "numeric",
"flink": "DOUBLE NOT NULL",
},
id="float64",
),
Expand Down

0 comments on commit 2f2d0d9

Please sign in to comment.