Improve Snowflake array and map support (#780)
* improve sf array and map support

* add parser support object_construct

* move to snowflake parser
eakmanrq committed Nov 29, 2022
1 parent 26b1da1 · commit 0506657
Showing 6 changed files with 29 additions and 3 deletions.
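
In practice, this change lets the Hive/Spark-style variadic MAP constructor transpile to Snowflake's OBJECT_CONSTRUCT, and typed ARRAY/MAP columns transpile to Snowflake's untyped ARRAY and OBJECT. A minimal usage sketch (not part of the diff; it assumes sqlglot is importable and mirrors the updated test_hive case below):

import sqlglot

# Hive's MAP(k1, v1, k2, v2, ...) becomes Snowflake's OBJECT_CONSTRUCT with the
# same interleaved key/value arguments.
print(sqlglot.transpile("MAP(a, b, c, d)", read="hive", write="snowflake")[0])
# Expected output, per the updated tests: OBJECT_CONSTRUCT(a, b, c, d)
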
6 changes: 3 additions & 3 deletions sqlglot/dialects/dialect.py
@@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
return f"{this}.{struct_key}"


def var_map_sql(self, expression):
def var_map_sql(self, expression, map_func_name="MAP"):
keys = expression.args["keys"]
values = expression.args["values"]

if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")
return f"MAP({self.format_args(keys, values)})"
return f"{map_func_name}({self.format_args(keys, values)})"

args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return f"MAP({self.format_args(*args)})"
return f"{map_func_name}({self.format_args(*args)})"


def format_time_lambda(exp_class, dialect, default=None):
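
The new map_func_name parameter lets a dialect reuse var_map_sql with its own map constructor; Snowflake passes "OBJECT_CONSTRUCT" in the next file. Note the two code paths: literal key/value arrays are zipped into one interleaved argument list, while anything else (for example array-typed columns) triggers self.unsupported and falls back to passing both arrays through unchanged. A hedged sketch against the public transpile API, with inputs adapted from the Presto tests below:

import sqlglot
from sqlglot.errors import ErrorLevel, UnsupportedError

# Literal arrays: keys and values are interleaved pairwise.
print(sqlglot.transpile("MAP(ARRAY[a, b], ARRAY[c, d])", read="presto", write="snowflake")[0])
# Expected output, per the updated tests: OBJECT_CONSTRUCT(a, c, b, d)

# Array-typed columns: no pairwise form exists, so generation is flagged as unsupported.
try:
    sqlglot.transpile("MAP(a, b)", read="presto", write="snowflake", unsupported_level=ErrorLevel.RAISE)
except UnsupportedError as e:
    print(e)  # Cannot convert array columns into map.
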
13 changes: 13 additions & 0 deletions sqlglot/dialects/snowflake.py
@@ -6,6 +6,7 @@
format_time_lambda,
inline_array_sql,
rename_func,
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
@@ -100,6 +101,14 @@ def _parse_date_part(self):
return self.expression(exp.Extract, this=this, expression=expression)


def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
return "OBJECT"
return self.datatype_sql(expression)


class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -143,6 +152,7 @@ class Parser(parser.Parser):
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"DECODE": exp.Matches.from_arg_list,
"OBJECT_CONSTRUCT": parser.parse_var_map,
}

FUNCTION_PARSERS = {
@@ -198,7 +208,10 @@ class Generator(generator.Generator):
**generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
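
Taken together, the Snowflake changes cover both directions: _datatype_sql collapses parameterized ARRAY/MAP types to Snowflake's untyped ARRAY and OBJECT, and the new OBJECT_CONSTRUCT parser entry reads Snowflake maps back into a VarMap expression. A hedged sketch with inputs adapted from the test updates below (the last expectation is an assumption, not asserted by these tests):

import sqlglot

# Parameterized container types lose their element types, since Snowflake's
# ARRAY and OBJECT are untyped.
print(sqlglot.transpile("CAST(a AS ARRAY<INT>)", read="spark", write="snowflake")[0])
# Expected output, per the updated tests: CAST(a AS ARRAY)
print(sqlglot.transpile("CAST(MAP(1, 1) AS MAP<INT, INT>)", read="hive", write="snowflake")[0])
# Expected output, per the updated tests: CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)

# Reading Snowflake: OBJECT_CONSTRUCT now parses as a variadic map, so it can be
# rendered in another dialect.
print(sqlglot.transpile("OBJECT_CONSTRUCT('a', 'b')", read="snowflake", write="hive")[0])
# Assumed output, based on the hive/presto test pairs: MAP('a', 'b')
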
1 change: 1 addition & 0 deletions tests/dialects/test_duckdb.py
@@ -90,6 +90,7 @@ def test_duckdb(self):
"hive": "CAST(COL AS ARRAY<BIGINT>)",
"spark": "CAST(COL AS ARRAY<LONG>)",
"postgres": "CAST(COL AS BIGINT[])",
"snowflake": "CAST(COL AS ARRAY)",
},
)

3 changes: 3 additions & 0 deletions tests/dialects/test_hive.py
@@ -459,6 +459,7 @@ def test_hive(self):
"hive": "MAP(a, b, c, d)",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"spark": "MAP(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
write={
"": "MAP(ARRAY(a, c), ARRAY(b, d))",
@@ -467,6 +468,7 @@
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"hive": "MAP(a, b, c, d)",
"spark": "MAP(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
)
self.validate_all(
@@ -476,6 +478,7 @@
"presto": "MAP(ARRAY[a], ARRAY[b])",
"hive": "MAP(a, b)",
"spark": "MAP(a, b)",
"snowflake": "OBJECT_CONSTRUCT(a, b)",
},
)
self.validate_all(
7 changes: 7 additions & 0 deletions tests/dialects/test_presto.py
@@ -13,6 +13,7 @@ def test_cast(self):
"duckdb": "CAST(a AS INT[])",
"presto": "CAST(a AS ARRAY(INTEGER))",
"spark": "CAST(a AS ARRAY<INT>)",
"snowflake": "CAST(a AS ARRAY)",
},
)
self.validate_all(
@@ -31,6 +32,7 @@ def test_cast(self):
"duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])",
"presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
"spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)",
"snowflake": "CAST([1, 2] AS ARRAY)",
},
)
self.validate_all(
@@ -41,6 +43,7 @@
"presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
"hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
"snowflake": "CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)",
},
)
self.validate_all(
@@ -51,6 +54,7 @@
"presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))",
"hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)",
"snowflake": "CAST(OBJECT_CONSTRUCT('a', [1], 'b', [2], 'c', [3]) AS OBJECT)",
},
)
self.validate_all(
@@ -393,6 +397,7 @@ def test_presto(self):
write={
"hive": UnsupportedError,
"spark": "MAP_FROM_ARRAYS(a, b)",
"snowflake": UnsupportedError,
},
)
self.validate_all(
@@ -401,6 +406,7 @@
"hive": "MAP(a, c, b, d)",
"presto": "MAP(ARRAY[a, b], ARRAY[c, d])",
"spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))",
"snowflake": "OBJECT_CONSTRUCT(a, c, b, d)",
},
)
self.validate_all(
@@ -409,6 +415,7 @@
"hive": "MAP('a', 'b')",
"presto": "MAP(ARRAY['a'], ARRAY['b'])",
"spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))",
"snowflake": "OBJECT_CONSTRUCT('a', 'b')",
},
)
self.validate_all(
2 changes: 2 additions & 0 deletions tests/dialects/test_spark.py
@@ -32,6 +32,7 @@ def test_ddl(self):
"presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))",
"hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
"spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
"snowflake": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY)",
},
)
self.validate_all(
@@ -278,6 +279,7 @@ def test_spark(self):
"presto": "MAP(ARRAY[1], c)",
"hive": "MAP(ARRAY(1), c)",
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
"snowflake": "OBJECT_CONSTRUCT([1], c)",
},
)
self.validate_all(
