diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 18328902df..0925238a97 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -105,7 +105,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: scale = None ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) - to_unix = self.expression(exp.TimeToUnix, this=ts) + to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) if scale: to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 3d43793db5..2a4c43dce5 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,7 +1,14 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + max_or_greatest, + min_or_least, +) from sqlglot.tokens import TokenType @@ -115,6 +122,14 @@ def _parse_rangen(self): return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) + def _parse_cast(self, strict: bool) -> exp.Expression: + cast = t.cast(exp.Cast, super()._parse_cast(strict)) + if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT): + return format_time_lambda(exp.TimeToStr, "teradata")( + [cast.this, self._parse_string()] + ) + return cast + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -130,6 +145,7 @@ class Generator(generator.Generator): **generator.Generator.TRANSFORMS, exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), } diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 36b6adcbba..f251f02097 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -18,6 +18,8 @@ logger = logging.getLogger("sqlglot") +E = t.TypeVar("E", bound=exp.Expression) + def parse_var_map(args: t.Sequence) -> exp.Expression: keys = [] @@ -927,8 +929,8 @@ def raise_error(self, message: str, token: t.Optional[Token] = None) -> None: self.errors.append(error) def expression( - self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs - ) -> exp.Expression: + self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs + ) -> E: """ Creates a new, validated Expression. @@ -984,7 +986,7 @@ def _retreat(self, index: int) -> None: if index != self._index: self._advance(index - self._index) - def _parse_command(self) -> exp.Expression: + def _parse_command(self) -> exp.Command: return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: @@ -1029,7 +1031,7 @@ def _parse_statement(self) -> t.Optional[exp.Expression]: self._parse_query_modifiers(expression) return expression - def _parse_drop(self) -> t.Optional[exp.Expression]: + def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]: start = self._prev temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) @@ -4065,7 +4067,7 @@ def _parse_merge(self) -> exp.Expression: if self._match(TokenType.INSERT): _this = self._parse_star() if _this: - then = self.expression(exp.Insert, this=_this) + then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this) else: then = self.expression( exp.Insert, diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 5d4f7db08c..03eba44f1e 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -71,3 +71,15 @@ def test_datatype(self): ) self.validate_identity("CREATE TABLE z (a SYSUDTLIB.INT)") + + def test_cast(self): + self.validate_all( + "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')", + write={ + "teradata": "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')", + "databricks": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "mysql": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "spark": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "": "TIME_TO_STR('1992-01', 'YYYY-DD')", + }, + )