Skip to content

Commit

Permalink
Added support for multiple T-SQL features (#890)
Browse files Browse the repository at this point in the history
* Fixed several issues regarding T-SQL dialect

* Fixed commented sysdatetime

* Removed the arg_types

* Removed values wrapper

* Added wrapper, will remove this in a later PR

* Restructured the code

* Update sqlglot/dialects/tsql.py

Co-authored-by: Jo <46752250+GeorgeSittas@users.noreply.github.com>

Co-authored-by: Sebastiaan Fransen <s.fransen@vanlanschot.com>
Co-authored-by: Toby Mao <toby.mao@gmail.com>
Co-authored-by: Jo <46752250+GeorgeSittas@users.noreply.github.com>
  • Loading branch information
4 people committed Jan 6, 2023
1 parent 52db84c commit 1ac05d9
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 7 deletions.
3 changes: 3 additions & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ class Generator(generator.Generator):
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
}

WITH_PROPERTIES = {exp.Property}
Expand Down Expand Up @@ -342,4 +343,6 @@ def datatype_sql(self, expression):
and not expression.expressions
):
expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
19 changes: 18 additions & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ def _parse_format(args):
)


def _parse_eomonth(args):
date = seq_get(args, 0)
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")

if month_lag is None:
return exp.LastDateOfMonth(this=date)

# Remove month lag argument in parser as its compared with the number of arguments of the resulting class
args.remove(month_lag)

return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))


def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
Expand Down Expand Up @@ -256,12 +270,14 @@ class Parser(parser.Parser):
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": _parse_format,
"EOMONTH": _parse_eomonth,
}

VAR_LENGTH_DATATYPES = {
Expand Down Expand Up @@ -326,6 +342,7 @@ class Generator(generator.Generator):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,6 +2679,10 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}


class LastDateOfMonth(Func):
pass


class Extract(Func):
arg_types = {"this": True, "expression": True}

Expand Down
8 changes: 5 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,12 +1264,12 @@ def _parse_table_alias(self, alias_tokens=None):

return self.expression(exp.TableAlias, this=alias, columns=columns)

def _parse_subquery(self, this):
def _parse_subquery(self, this, parse_alias=True):
return self.expression(
exp.Subquery,
this=this,
pivots=self._parse_pivots(),
alias=self._parse_table_alias(),
alias=self._parse_table_alias() if parse_alias else None,
)

def _parse_query_modifiers(self, this):
Expand Down Expand Up @@ -2018,7 +2018,9 @@ def _parse_primary(self):
self._match_r_paren()

if isinstance(this, exp.Subqueryable):
this = self._parse_set_operations(self._parse_subquery(this))
this = self._parse_set_operations(
self._parse_subquery(this=this, parse_alias=False)
)
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_bigquery(self):
},
)
self.validate_all(
"CURRENT_DATE",
"CURRENT_TIMESTAMP()",
read={
"tsql": "GETDATE()",
},
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def test_cast(self):
"bigquery": "CAST(x AS TIMESTAMPTZ(9))",
"duckdb": "CAST(x AS TIMESTAMPTZ(9))",
"presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
"hive": "CAST(x AS TIMESTAMPTZ(9))",
"spark": "CAST(x AS TIMESTAMPTZ(9))",
"hive": "CAST(x AS TIMESTAMPTZ)",
"spark": "CAST(x AS TIMESTAMPTZ)",
},
)

Expand Down
22 changes: 22 additions & 0 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def test_types(self):
"tsql": "CAST(x AS DATETIME2)",
},
)
self.validate_all(
"CAST(x AS DATETIME2(6))",
write={
"hive": "CAST(x AS TIMESTAMP)",
},
)

def test_charindex(self):
self.validate_all(
Expand Down Expand Up @@ -302,6 +308,12 @@ def test_convert_date_format(self):
"spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
},
)
self.validate_all(
"SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test",
write={
"spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test",
},
)

def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
Expand Down Expand Up @@ -443,3 +455,13 @@ def test_string(self):
"SELECT '''test'''",
write={"spark": r"SELECT '\'test\''"},
)

def test_eomonth(self):
self.validate_all(
"EOMONTH(GETDATE())",
write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"},
)
self.validate_all(
"EOMONTH(GETDATE(), -1)",
write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
)

0 comments on commit 1ac05d9

Please sign in to comment.