diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 088555c9b4..17d5b0c20a 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -315,6 +315,7 @@ class Generator(generator.Generator): exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}", exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), + exp.LastDateOfMonth: rename_func("LAST_DAY"), } WITH_PROPERTIES = {exp.Property} @@ -342,4 +343,6 @@ def datatype_sql(self, expression): and not expression.expressions ): expression = exp.DataType.build("text") + elif expression.this in exp.DataType.TEMPORAL_TYPES: + expression = exp.DataType.build(expression.this) return super().datatype_sql(expression) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 9cedbcca72..465f53480d 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -75,6 +75,20 @@ def _parse_format(args): ) +def _parse_eomonth(args): + date = seq_get(args, 0) + month_lag = seq_get(args, 1) + unit = DATE_DELTA_INTERVAL.get("month") + + if month_lag is None: + return exp.LastDateOfMonth(this=date) + + # Remove month lag argument in parser as its compared with the number of arguments of the resulting class + args.remove(month_lag) + + return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) + + def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" @@ -256,12 +270,14 @@ class Parser(parser.Parser): "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), - "GETDATE": exp.CurrentDate.from_arg_list, + "GETDATE": exp.CurrentTimestamp.from_arg_list, + "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, "IIF": exp.If.from_arg_list, "LEN": exp.Length.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, "FORMAT": _parse_format, + "EOMONTH": _parse_eomonth, } VAR_LENGTH_DATATYPES = { @@ -326,6 +342,7 @@ class Generator(generator.Generator): exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), + exp.CurrentTimestamp: rename_func("GETDATE"), exp.If: rename_func("IIF"), exp.NumberToStr: _format_sql, exp.TimeToStr: _format_sql, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 544849bfc5..2d05f3e813 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2679,6 +2679,10 @@ class DatetimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class LastDateOfMonth(Func): + pass + + class Extract(Func): arg_types = {"this": True, "expression": True} diff --git a/sqlglot/parser.py b/sqlglot/parser.py index dbac57d040..36a923c476 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1264,12 +1264,12 @@ def _parse_table_alias(self, alias_tokens=None): return self.expression(exp.TableAlias, this=alias, columns=columns) - def _parse_subquery(self, this): + def _parse_subquery(self, this, parse_alias=True): return self.expression( exp.Subquery, this=this, pivots=self._parse_pivots(), - alias=self._parse_table_alias(), + alias=self._parse_table_alias() if parse_alias else None, ) def _parse_query_modifiers(self, this): @@ -2018,7 +2018,9 @@ def _parse_primary(self): self._match_r_paren() if isinstance(this, exp.Subqueryable): - this = self._parse_set_operations(self._parse_subquery(this)) + this = self._parse_set_operations( + self._parse_subquery(this=this, parse_alias=False) + ) elif len(expressions) > 1: this = self.expression(exp.Tuple, expressions=expressions) else: diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 258e47fd0d..f96bb66098 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -125,7 +125,7 @@ def test_bigquery(self): }, ) self.validate_all( - "CURRENT_DATE", + "CURRENT_TIMESTAMP()", read={ "tsql": "GETDATE()", }, diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 70e1059828..32eaf3dc09 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -63,8 +63,8 @@ def test_cast(self): "bigquery": "CAST(x AS TIMESTAMPTZ(9))", "duckdb": "CAST(x AS TIMESTAMPTZ(9))", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", - "hive": "CAST(x AS TIMESTAMPTZ(9))", - "spark": "CAST(x AS TIMESTAMPTZ(9))", + "hive": "CAST(x AS TIMESTAMPTZ)", + "spark": "CAST(x AS TIMESTAMPTZ)", }, ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index f980e8b46b..b74c05f764 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -73,6 +73,12 @@ def test_types(self): "tsql": "CAST(x AS DATETIME2)", }, ) + self.validate_all( + "CAST(x AS DATETIME2(6))", + write={ + "hive": "CAST(x AS TIMESTAMP)", + }, + ) def test_charindex(self): self.validate_all( @@ -302,6 +308,12 @@ def test_convert_date_format(self): "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", }, ) + self.validate_all( + "SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test", + write={ + "spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test", + }, + ) def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") @@ -443,3 +455,13 @@ def test_string(self): "SELECT '''test'''", write={"spark": r"SELECT '\'test\''"}, ) + + def test_eomonth(self): + self.validate_all( + "EOMONTH(GETDATE())", + write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"}, + ) + self.validate_all( + "EOMONTH(GETDATE(), -1)", + write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, + )