Skip to content

Commit

Permalink
fix(spark): Custom annotation for SUBSTRING() (#4004)
Browse files Browse the repository at this point in the history
* fix(spark): Custom annotation for SUBSTRING()

* Remove empty line

* PR Feedback 1
  • Loading branch information
VaggelisD authored Aug 29, 2024
1 parent fcaae87 commit 4b7ca2b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
5 changes: 5 additions & 0 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:


class Spark2(Hive):
ANNOTATORS = {
**Hive.ANNOTATORS,
exp.Substring: lambda self, e: self._annotate_by_args(e, "this"),
}

class Parser(Hive.Parser):
TRIM_PATTERN_FIRST = True

Expand Down
23 changes: 23 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,3 +1345,26 @@ def gen_expr(depth: int) -> exp.Expression:
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))

def test_custom_annotators(self):
    # Dialects in the Spark family annotate SUBSTRING() with the type of its
    # first argument, so the result may be STRING or BINARY.
    #
    # Each case is (first argument passed to SUBSTRING, type name). The type
    # name is used both as the schema type of `col` and as the expected
    # result type of the projection.
    cases = (
        ("col", "STRING"),
        ("col", "BINARY"),
        ("'str_literal'", "STRING"),
        ("CAST('str_literal' AS BINARY)", "BINARY"),
    )

    for dialect in ("spark2", "spark", "databricks"):
        for case in cases:
            with self.subTest(f"Testing {dialect}'s SUBSTRING() result type for {case}"):
                operand, type_name = case
                query = f"SELECT substring({operand}, 2, 3) AS x FROM tbl"
                tree = parse_one(query, read=dialect)

                # Run the optimizer with a schema that declares `col`, then
                # read the type annotated on the single projection.
                optimized = optimizer.optimize(
                    tree, schema={"tbl": {"col": type_name}}, dialect=dialect
                )
                projection_type = optimized.expressions[0].type

                # Compare rendered SQL so both sides use the dialect's own
                # spelling of the type.
                actual_sql = projection_type.sql(dialect)
                expected_sql = exp.DataType.build(type_name).sql(dialect)
                self.assertEqual(actual_sql, expected_sql)

0 comments on commit 4b7ca2b

Please sign in to comment.