Skip to content

Commit

Permalink
Fix(clickhouse)!: use creatable kind mapping dict for schema<-->datab…
Browse files Browse the repository at this point in the history
…ase substitution (#3924)

* Transpile schema to database in drop statements

* Use kind mapping dict instead of method overrides
  • Loading branch information
treysp authored Aug 18, 2024
1 parent c59cd94 commit 605f1b2
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 16 deletions.
15 changes: 2 additions & 13 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class ClickHouse(Dialect):
"\\0": "\0",
}

CREATABLE_KIND_MAPPING = {"DATABASE": "SCHEMA"}

class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
Expand Down Expand Up @@ -444,15 +446,6 @@ def _parse_types(

return dtype

def _parse_create(self) -> exp.Create | exp.Command:
create = super()._parse_create()

# DATABASE in ClickHouse is the same as SCHEMA in other dialects
if isinstance(create, exp.Create) and create.kind == "DATABASE":
create.set("kind", "SCHEMA")

return create

def _parse_extract(self) -> exp.Extract | exp.Anonymous:
index = self._index
this = self._parse_bitwise()
Expand Down Expand Up @@ -1046,10 +1039,6 @@ def create_sql(self, expression: exp.Create) -> str:
else:
comment_prop = None

# ClickHouse only has DATABASEs and objects under them, eg. TABLEs, VIEWs, etc
if expression.kind == "SCHEMA":
expression.set("kind", "DATABASE")

create_sql = super().create_sql(expression)

comment_sql = self.sql(comment_prop)
Expand Down
12 changes: 12 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def __new__(cls, clsname, bases, attrs):
klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

klass.INVERSE_CREATABLE_KIND_MAPPING = {
v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
}

base = seq_get(bases, 0)
base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
Expand Down Expand Up @@ -384,6 +388,12 @@ class Dialect(metaclass=_Dialect):
dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator
"""

CREATABLE_KIND_MAPPING: dict[str, str] = {}
"""
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
equivalent of CREATE SCHEMA is CREATE DATABASE.
"""

# --- Autofilled ---

tokenizer_class = Tokenizer
Expand All @@ -400,6 +410,8 @@ class Dialect(metaclass=_Dialect):
INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
INVERSE_FORMAT_TRIE: t.Dict = {}

INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {}

ESCAPED_SEQUENCES: t.Dict[str, str] = {}

# Delimiters for string literals and identifiers
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,6 +2001,11 @@ class Drop(Expression):
"cluster": False,
}

@property
def kind(self) -> t.Optional[str]:
kind = self.args.get("kind")
return kind and kind.upper()


class Filter(Expression):
arg_types = {"this": True, "expression": True}
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> st

def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind")
kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind
properties = expression.args.get("properties")
properties_locs = self.locate_properties(properties) if properties else defaultdict()

Expand Down Expand Up @@ -1278,6 +1279,7 @@ def drop_sql(self, expression: exp.Drop) -> str:
expressions = self.expressions(expression, flat=True)
expressions = f" ({expressions})" if expressions else ""
kind = expression.args["kind"]
kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
on_cluster = self.sql(expression, "cluster")
on_cluster = f" {on_cluster}" if on_cluster else ""
Expand Down
7 changes: 4 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,7 @@ def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command:
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match_text_seq("MATERIALIZED")

kind = self._match_set(self.CREATABLES) and self._prev.text
kind = self._match_set(self.CREATABLES) and self._prev.text.upper()
if not kind:
return self._parse_as_command(start)

Expand All @@ -1687,7 +1687,7 @@ def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command:
exists=if_exists,
this=table,
expressions=expressions,
kind=kind.upper(),
kind=self.dialect.CREATABLE_KIND_MAPPING.get(kind) or kind,
temporary=temporary,
materialized=materialized,
cascade=self._match_text_seq("CASCADE"),
Expand Down Expand Up @@ -1850,11 +1850,12 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
if self._curr and not self._match_set((TokenType.R_PAREN, TokenType.COMMA), advance=False):
return self._parse_as_command(start)

create_kind_text = create_token.text.upper()
return self.expression(
exp.Create,
comments=comments,
this=this,
kind=create_token.text.upper(),
kind=self.dialect.CREATABLE_KIND_MAPPING.get(create_kind_text) or create_kind_text,
replace=replace,
refresh=refresh,
unique=unique,
Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,16 @@ def test_ddl(self):
"duckdb": "CREATE SCHEMA x",
},
)
self.validate_all(
"DROP DATABASE x",
read={
"duckdb": "DROP SCHEMA x",
},
write={
"clickhouse": "DROP DATABASE x",
"duckdb": "DROP SCHEMA x",
},
)
self.validate_all(
"""
CREATE TABLE example1 (
Expand Down

0 comments on commit 605f1b2

Please sign in to comment.