Skip to content

Commit

Permalink
handle teradata date format
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Apr 21, 2023
1 parent f7eafce commit 35518cf
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
scale = None

ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
to_unix = self.expression(exp.TimeToUnix, this=ts)
to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)

if scale:
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
Expand Down
18 changes: 17 additions & 1 deletion sqlglot/dialects/teradata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
max_or_greatest,
min_or_least,
)
from sqlglot.tokens import TokenType


Expand Down Expand Up @@ -115,6 +122,14 @@ def _parse_rangen(self):

return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)

def _parse_cast(self, strict: bool) -> exp.Expression:
cast = t.cast(exp.Cast, super()._parse_cast(strict))
if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
return format_time_lambda(exp.TimeToStr, "teradata")(
[cast.this, self._parse_string()]
)
return cast

class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
Expand All @@ -130,6 +145,7 @@ class Generator(generator.Generator):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

Expand Down
12 changes: 7 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

logger = logging.getLogger("sqlglot")

E = t.TypeVar("E", bound=exp.Expression)


def parse_var_map(args: t.Sequence) -> exp.Expression:
keys = []
Expand Down Expand Up @@ -927,8 +929,8 @@ def raise_error(self, message: str, token: t.Optional[Token] = None) -> None:
self.errors.append(error)

def expression(
self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
) -> exp.Expression:
self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs
) -> E:
"""
Creates a new, validated Expression.
Expand Down Expand Up @@ -984,7 +986,7 @@ def _retreat(self, index: int) -> None:
if index != self._index:
self._advance(index - self._index)

def _parse_command(self) -> exp.Expression:
def _parse_command(self) -> exp.Command:
return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())

def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
Expand Down Expand Up @@ -1029,7 +1031,7 @@ def _parse_statement(self) -> t.Optional[exp.Expression]:
self._parse_query_modifiers(expression)
return expression

def _parse_drop(self) -> t.Optional[exp.Expression]:
def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
Expand Down Expand Up @@ -4065,7 +4067,7 @@ def _parse_merge(self) -> exp.Expression:
if self._match(TokenType.INSERT):
_this = self._parse_star()
if _this:
then = self.expression(exp.Insert, this=_this)
then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this)
else:
then = self.expression(
exp.Insert,
Expand Down
12 changes: 12 additions & 0 deletions tests/dialects/test_teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,15 @@ def test_datatype(self):
)

self.validate_identity("CREATE TABLE z (a SYSUDTLIB.INT)")

def test_cast(self):
self.validate_all(
"CAST('1992-01' AS DATE FORMAT 'YYYY-DD')",
write={
"teradata": "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')",
"databricks": "DATE_FORMAT('1992-01', 'YYYY-DD')",
"mysql": "DATE_FORMAT('1992-01', 'YYYY-DD')",
"spark": "DATE_FORMAT('1992-01', 'YYYY-DD')",
"": "TIME_TO_STR('1992-01', 'YYYY-DD')",
},
)

0 comments on commit 35518cf

Please sign in to comment.