Skip to content

Commit

Permalink
Refactor properties (#733)
Browse files (browse the repository at this point in the history)
* Generator fixup

* Start refactoring the way we handle properties

* Get rid of 'this' for props, cleanup

* Compound key support for redshift SORTKEY

* PR feedback
  • Loading branch information
georgesittas committed Nov 19, 2022
1 parent 8a32dd2 commit be332d1
Show file tree
Hide file tree
Showing 16 changed files with 124 additions and 189 deletions.
14 changes: 6 additions & 8 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):


def _returnsproperty_sql(self, expression):
value = expression.args.get("value")
if isinstance(value, exp.Schema):
value = f"{value.this} <{self.expressions(value)}>"
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
else:
value = self.sql(value)
return f"RETURNS {value}"
this = self.sql(this)
return f"RETURNS {this}"


def _create_sql(self, expression):
Expand Down Expand Up @@ -200,9 +200,7 @@ class Generator(generator.Generator):
exp.VolatilityProperty,
}

WITH_PROPERTIES = {
exp.AnonymousProperty,
}
WITH_PROPERTIES = {exp.Property}

EXPLICIT_UNION = True

Expand Down
15 changes: 5 additions & 10 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
value = prop and prop.args.get("value")
if prop and not isinstance(value, exp.Schema):
this = prop and prop.this
if prop and not isinstance(this, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in value.expressions}
columns = {v.name.upper() for v in this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
prop.replace(
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
)
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)

return self.create_sql(expression)
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class Generator(generator.Generator):
exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql,
Expand Down
14 changes: 6 additions & 8 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def _array_sort(self, expression):


def _property_sql(self, expression):
key = expression.name
value = self.sql(expression, "value")
return f"'{key}'={value}"
return f"'{expression.name}'={self.sql(expression, 'value')}"


def _str_to_unix(self, expression):
Expand Down Expand Up @@ -250,7 +248,7 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql,
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
Expand All @@ -262,7 +260,7 @@ class Generator(generator.Generator):
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
Expand All @@ -285,7 +283,7 @@ class Generator(generator.Generator):
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
Expand All @@ -298,11 +296,11 @@ class Generator(generator.Generator):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}

WITH_PROPERTIES = {exp.AnonymousProperty}
WITH_PROPERTIES = {exp.Property}

ROOT_PROPERTIES = {
exp.PartitionedByProperty,
Expand Down
14 changes: 3 additions & 11 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,7 @@ class Generator(generator.Generator):

STRUCT_DELIMITER = ("(", ")")

ROOT_PROPERTIES = {
exp.SchemaCommentProperty,
}

WITH_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.AnonymousProperty,
exp.TableFormatProperty,
}
ROOT_PROPERTIES = {exp.SchemaCommentProperty}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down Expand Up @@ -231,7 +222,8 @@ class Generator(generator.Generator):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
Expand Down
13 changes: 13 additions & 0 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,16 @@ class Generator(Postgres.Generator):
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}

ROOT_PROPERTIES = {
exp.DistKeyProperty,
exp.SortKeyProperty,
exp.DistStyleProperty,
}

# Redshift generator overrides: start from the Postgres transforms and add
# renderers for the three Redshift table properties listed in this hunk.
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
# Rendered as DISTKEY(<name>) using the expression's `name`.
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
# NOTE(review): `e.args['compound']` is a plain subscript, so it raises KeyError when
# the key is absent; SortKeyProperty lists "compound": False in its arg_types
# (presumably meaning optional), so `e.args.get('compound')` looks safer — confirm.
# NOTE(review): `*e.this` assumes `this` is an iterable of key expressions — confirm.
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
# Delegates to naked_property, which emits "<NAME> <this>".
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
}
2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class Generator(generator.Generator):
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
}

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class Generator(Hive.Generator):
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
Expand Down
57 changes: 31 additions & 26 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,67 +1065,63 @@ class Property(Expression):


class TableFormatProperty(Property):
pass
arg_types = {"this": True}


class PartitionedByProperty(Property):
pass
arg_types = {"this": True}


class FileFormatProperty(Property):
pass
arg_types = {"this": True}


class DistKeyProperty(Property):
pass
arg_types = {"this": True}


class SortKeyProperty(Property):
pass
arg_types = {"this": True, "compound": False}


class DistStyleProperty(Property):
pass
arg_types = {"this": True}


class LocationProperty(Property):
pass
arg_types = {"this": True}


class EngineProperty(Property):
pass
arg_types = {"this": True}


class AutoIncrementProperty(Property):
pass
arg_types = {"this": True}


class CharacterSetProperty(Property):
arg_types = {"this": True, "value": True, "default": True}
arg_types = {"this": True, "default": True}


class CollateProperty(Property):
pass
arg_types = {"this": True}


class SchemaCommentProperty(Property):
pass


class AnonymousProperty(Property):
pass
arg_types = {"this": True}


class ReturnsProperty(Property):
arg_types = {"this": True, "value": True, "is_table": False}
arg_types = {"this": True, "is_table": False}


class LanguageProperty(Property):
pass
arg_types = {"this": True}


class ExecuteAsProperty(Property):
pass
arg_types = {"this": True}


class VolatilityProperty(Property):
Expand All @@ -1135,27 +1131,36 @@ class VolatilityProperty(Property):
class Properties(Expression):
arg_types = {"expressions": True}

PROPERTY_KEY_MAPPING = {
NAME_TO_PROPERTY = {
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER_SET": CharacterSetProperty,
"CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
"EXECUTE AS": ExecuteAsProperty,
"FORMAT": FileFormatProperty,
"LANGUAGE": LanguageProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"RETURNS": ReturnsProperty,
"SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
}

PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}

# Build a Properties node from a plain {name: value} mapping.
# NOTE(review): this hunk has lost its diff markers and indentation; the two lines that
# reference PROPERTY_KEY_MAPPING / AnonymousProperty appear to be the removed side and
# the NAME_TO_PROPERTY lines the added side — confirm against the real diff.
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
# Keys are upper-cased before the lookup, so the name match is case-insensitive.
property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
if property_cls:
# NOTE(review): only `this` is supplied here. CharacterSetProperty declares
# "default": True in its arg_types (presumably required), so names such as
# "CHARACTER SET" would be built without it — confirm this is acceptable.
expressions.append(property_cls(this=convert(value)))
else:
# Names with no dedicated class fall back to a generic Property carrying the
# key as a string literal in `this` and the converted value in `value`.
expressions.append(Property(this=Literal.string(key), value=convert(value)))

return cls(expressions=expressions)


Expand Down
41 changes: 18 additions & 23 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ class Generator:
"""

TRANSFORMS = {
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
Expand Down Expand Up @@ -100,7 +100,7 @@ class Generator:
}

WITH_PROPERTIES = {
exp.AnonymousProperty,
exp.Property,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
Expand Down Expand Up @@ -546,36 +546,28 @@ def properties_sql(self, expression):

def root_properties(self, properties):
if properties.expressions:
return self.sep() + self.expressions(
properties,
indent=False,
sep=" ",
)
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""

def properties(self, properties, prefix="", sep=", "):
if properties.expressions:
expressions = self.expressions(
properties,
sep=sep,
indent=False,
)
expressions = self.expressions(properties, sep=sep, indent=False)
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""

def with_properties(self, properties):
return self.properties(
properties,
prefix="WITH",
)
return self.properties(properties, prefix="WITH")

# Render a single property as "<name>=<value>".
# NOTE(review): this hunk has lost its diff markers and indentation; the isinstance/key/
# value lines appear to be the removed implementation and the lines from
# `property_cls = ...` onward the added one — confirm against the real diff.
def property_sql(self, expression):
if isinstance(expression.this, exp.Literal):
key = expression.this.this
else:
key = expression.name
value = self.sql(expression, "value")
return f"{key}={value}"
property_cls = expression.__class__
# A plain exp.Property carries its own name and keeps the value under "value".
if property_cls == exp.Property:
return f"{expression.name}={self.sql(expression, 'value')}"

# Subclasses get their keyword from the class -> name mapping on exp.Properties.
property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
if not property_name:
# NOTE(review): `property_name` is falsy on this branch, so the message cannot name
# the offending property; naked_property reports `expression.__class__.__name__`
# instead, which looks like the intended text — confirm and align.
self.unsupported(f"Unsupported property {property_name}")

# NOTE(review): there is no early return after unsupported(); if that call returns,
# the falsy property_name is formatted into the result below — confirm intended.
return f"{property_name}={self.sql(expression, 'this')}"

def insert_sql(self, expression):
overwrite = expression.args.get("overwrite")
Expand Down Expand Up @@ -1354,7 +1346,10 @@ def op_expressions(self, op, expression, flat=False):
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"

# Render a property as "<NAME> <this>" (space-separated, no "=").
# NOTE(review): this hunk has lost its diff markers and indentation; the first `return`
# appears to be the removed line and the remaining lines the added body — confirm.
def naked_property(self, expression):
return f"{expression.name} {self.sql(expression, 'value')}"
# The keyword comes from the class -> name mapping on exp.Properties.
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
if not property_name:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
# NOTE(review): there is no early return after unsupported(); if that call returns,
# the falsy property_name is formatted into the result below — confirm intended.
return f"{property_name} {self.sql(expression, 'this')}"

def set_operation(self, expression, op):
this = self.sql(expression, "this")
Expand Down
Loading

0 comments on commit be332d1

Please sign in to comment.