Skip to content

Commit

Permalink
Add transaction/commit expressions (#684)
Browse files Browse the repository at this point in the history
* Add transaction expression

* pr feedback, fix bigquery issue

* Add Commit expression

* Bring rollback token back

* comment/test for bigquery BEGIN block statement

* Simplify parse transaction

* add expression node for rollback, cleanup

* fixup
  • Loading branch information
georgesittas authored Nov 10, 2022
1 parent 968a0be commit 148282e
Show file tree
Hide file tree
Showing 16 changed files with 200 additions and 38 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive"
"SELECT DATE_FORMAT(x, 'yy-M-ss')"
```

As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks as identifiers and `FLOAT` instead of `REAL`:
As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks for identifiers and `FLOAT` instead of `REAL`:

```python
import sqlglot
Expand Down
15 changes: 15 additions & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ class Parser(parser.Parser):
TokenType.TABLE,
}

# In BigQuery a bare BEGIN opens a block statement, so only BEGIN TRANSACTION is parsed as a
# transaction; a bare BEGIN is kept as a generic Command.
def _parse_transaction(self):
    """Parse what follows a BEGIN token in BigQuery.

    Returns:
        An ``exp.Transaction`` node when the next word is ``TRANSACTION``;
        otherwise an ``exp.Command`` whose ``this`` is the text of the
        previously consumed token.
    """
    if not self._match_text_seq("TRANSACTION"):
        # Nothing was consumed, so ``self._prev`` is still the token that
        # led here; its text becomes the command name.
        return self.expression(exp.Command, this=self._prev.text)

    return self.expression(exp.Transaction)

class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down Expand Up @@ -204,6 +210,15 @@ class Generator(generator.Generator):

EXPLICIT_UNION = True

def transaction_sql(self, *_):
    """Render the BigQuery statement that opens a transaction.

    Extra positional arguments (such as the expression node) are accepted
    and ignored; the output is always the same fixed string.
    """
    return "BEGIN TRANSACTION"

def commit_sql(self, *_):
    """Render the BigQuery statement that commits a transaction.

    Extra positional arguments (such as the expression node) are accepted
    and ignored; the output is always the same fixed string.
    """
    return "COMMIT TRANSACTION"

def rollback_sql(self, *_):
    """Render the BigQuery statement that rolls back a transaction.

    Extra positional arguments (such as the expression node) are accepted
    and ignored, so any savepoint carried by the node is not emitted.
    """
    return "ROLLBACK TRANSACTION"

def in_unnest_op(self, unnest):
    """Return ``unnest`` exactly as ``self.sql`` renders it, with nothing added around it."""
    rendered = self.sql(unnest)
    return rendered

Expand Down
29 changes: 15 additions & 14 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class Tokenizer(tokens.Tokenizer):

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
Expand Down Expand Up @@ -281,36 +282,36 @@ class Parser(parser.Parser):
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
self._match_text(target)
self._match_text_seq(target)
target_id = self._parse_id_var()
else:
target_id = None

log = self._parse_string() if self._match_text("IN") else None
log = self._parse_string() if self._match_text_seq("IN") else None

if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
position = self._parse_number() if self._match_text("FROM") else None
position = self._parse_number() if self._match_text_seq("FROM") else None
db = None
else:
position = None
db = self._parse_id_var() if self._match_text("FROM") else None
db = self._parse_id_var() if self._match_text_seq("FROM") else None

channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None

like = self._parse_string() if self._match_text("LIKE") else None
like = self._parse_string() if self._match_text_seq("LIKE") else None
where = self._parse_where()

if this == "PROFILE":
types = self._parse_csv(self._parse_show_profile_type)
query = self._parse_number() if self._match_text("FOR", "QUERY") else None
offset = self._parse_number() if self._match_text("OFFSET") else None
limit = self._parse_number() if self._match_text("LIMIT") else None
query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
offset = self._parse_number() if self._match_text_seq("OFFSET") else None
limit = self._parse_number() if self._match_text_seq("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()

mutex = True if self._match_text("MUTEX") else None
mutex = False if self._match_text("STATUS") else mutex
mutex = True if self._match_text_seq("MUTEX") else None
mutex = False if self._match_text_seq("STATUS") else mutex

return self.expression(
exp.Show,
Expand All @@ -333,14 +334,14 @@ def _parse_show_mysql(self, this, target=False, full=None, global_=None):

def _parse_show_profile_type(self):
    """Match one profile type at the current position.

    Each entry of ``self.PROFILE_TYPES`` is split on single spaces and the
    resulting words are matched as one consecutive sequence.

    Returns:
        An ``exp.Var`` wrapping the first entry that matches, or ``None``
        when no entry matches.
    """
    for type_ in self.PROFILE_TYPES:
        # A single condition per entry, using the sequence matcher.
        if self._match_text_seq(*type_.split(" ")):
            return exp.Var(this=type_)
    return None

def _parse_oldstyle_limit(self):
limit = None
offset = None
if self._match_text("LIMIT"):
if self._match_text_seq("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
Expand Down Expand Up @@ -381,7 +382,7 @@ def _parse_set_item_charset(self, kind):

def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
if self._match_text("COLLATE"):
if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def table_sql(self, expression):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
Expand Down
6 changes: 6 additions & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Presto(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"ROW": TokenType.STRUCT,
}

Expand Down Expand Up @@ -216,3 +217,8 @@ class Generator(generator.Generator):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}

def transaction_sql(self, expression):
    """Render ``START TRANSACTION`` followed by any transaction modes.

    Args:
        expression: node whose ``args`` mapping may hold a ``"modes"`` entry,
            an iterable of mode strings.

    Returns:
        ``"START TRANSACTION"`` alone when there are no modes, otherwise the
        same text followed by a space and the modes joined with ``", "``.
    """
    modes = expression.args.get("modes")
    if not modes:
        return "START TRANSACTION"
    return "START TRANSACTION " + ", ".join(modes)
5 changes: 5 additions & 0 deletions sqlglot/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,8 @@ class Generator(generator.Generator):
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}

def transaction_sql(self, expression):
    """Render ``BEGIN [<kind>] TRANSACTION``.

    Args:
        expression: node whose ``this`` attribute holds the optional word
            placed between ``BEGIN`` and ``TRANSACTION``.

    Returns:
        ``"BEGIN TRANSACTION"`` when ``this`` is falsy, otherwise
        ``"BEGIN <this> TRANSACTION"``.
    """
    kind = expression.this
    if kind:
        return f"BEGIN {kind} TRANSACTION"
    return "BEGIN TRANSACTION"
21 changes: 15 additions & 6 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2054,16 +2054,25 @@ class Exists(SubqueryPredicate):
pass


# Commands to interact with the databases or engines
# These expressions don't truly parse the expression and consume
# whatever exists as a string until the end or a semicolon
# Commands to interact with the databases or engines. For most of the command
# expressions we parse whatever comes after the command's name as a string.
class Command(Expression):
    """Generic command node.

    Declares two arguments, ``this`` (flag ``True``) and ``expression``
    (flag ``False``); the flags presumably mark required vs. optional
    arguments -- confirm against ``Expression``.
    """

    arg_types = {
        "this": True,
        "expression": False,
    }


# Binary Expressions
# (ADD a b)
# (FROM table selects)
class Transaction(Command):
    """Node for a statement that starts a transaction.

    Declares two arguments, both flagged ``False``: ``this`` and ``modes``.
    """

    arg_types = {
        "this": False,
        "modes": False,
    }


class Commit(Command):
    """Node for a COMMIT statement; it declares no arguments."""

    arg_types = {}  # type: ignore


class Rollback(Command):
    """Node for a ROLLBACK statement.

    Declares a single argument, ``savepoint``, flagged ``False``.
    """

    arg_types = {
        "savepoint": False,
    }


# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}

Expand Down
11 changes: 11 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,17 @@ def currentdate_sql(self, expression):
def command_sql(self, expression):
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"

def transaction_sql(self, *_):
    """Render the default statement that opens a transaction.

    Extra positional arguments (such as the expression node) are accepted
    and ignored; the output is always the same fixed string.
    """
    return "BEGIN"

def commit_sql(self, *_):
    """Render the default statement that commits a transaction.

    Extra positional arguments (such as the expression node) are accepted
    and ignored; the output is always the same fixed string.
    """
    return "COMMIT"

def rollback_sql(self, expression):
    """Render ``ROLLBACK``, optionally followed by ``TO <savepoint>``.

    Args:
        expression: node that may carry a ``savepoint`` argument.

    Returns:
        ``"ROLLBACK"`` when the savepoint renders to an empty string,
        otherwise ``"ROLLBACK TO <savepoint>"``.
    """
    # Render the savepoint through this generator (as ``command_sql`` does
    # for its arguments) rather than interpolating the node directly, so the
    # name is produced by the same generator as the rest of the statement.
    savepoint = self.sql(expression, "savepoint")
    savepoint = f" TO {savepoint}" if savepoint else ""
    return f"ROLLBACK{savepoint}"

def distinct_sql(self, expression):
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
Expand Down
69 changes: 57 additions & 12 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,9 @@ class Parser(metaclass=_Parser):
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self._parse_use(),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
}

PRIMARY_PARSERS = {
Expand Down Expand Up @@ -521,6 +524,8 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
}

TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}

STRICT_CAST = True

__slots__ = (
Expand Down Expand Up @@ -930,7 +935,7 @@ def _parse_describe(self):
def _parse_insert(self):
overwrite = self._match(TokenType.OVERWRITE)
local = self._match(TokenType.LOCAL)
if self._match_text("DIRECTORY"):
if self._match_text_seq("DIRECTORY"):
this = self.expression(
exp.Directory,
this=self._parse_var_or_string(),
Expand All @@ -954,27 +959,27 @@ def _parse_row_format(self):
if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None

self._match_text("DELIMITED")
self._match_text_seq("DELIMITED")

kwargs = {}

if self._match_text("FIELDS", "TERMINATED", "BY"):
if self._match_text_seq("FIELDS", "TERMINATED", "BY"):
kwargs["fields"] = self._parse_string()
if self._match_text("ESCAPED", "BY"):
if self._match_text_seq("ESCAPED", "BY"):
kwargs["escaped"] = self._parse_string()
if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"):
if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"):
kwargs["collection_items"] = self._parse_string()
if self._match_text("MAP", "KEYS", "TERMINATED", "BY"):
if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"):
kwargs["map_keys"] = self._parse_string()
if self._match_text("LINES", "TERMINATED", "BY"):
if self._match_text_seq("LINES", "TERMINATED", "BY"):
kwargs["lines"] = self._parse_string()
if self._match_text("NULL", "DEFINED", "AS"):
if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
return self.expression(exp.RowFormat, **kwargs)

def _parse_load_data(self):
local = self._match(TokenType.LOCAL)
self._match_text("INPATH")
self._match_text_seq("INPATH")
inpath = self._parse_string()
overwrite = self._match(TokenType.OVERWRITE)
self._match_pair(TokenType.INTO, TokenType.TABLE)
Expand All @@ -986,8 +991,8 @@ def _parse_load_data(self):
overwrite=overwrite,
inpath=inpath,
partition=self._parse_partition(),
input_format=self._match_text("INPUTFORMAT") and self._parse_string(),
serde=self._match_text("SERDE") and self._parse_string(),
input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
serde=self._match_text_seq("SERDE") and self._parse_string(),
)

def _parse_delete(self):
Expand Down Expand Up @@ -2594,6 +2599,40 @@ def _parse_select_or_expression(self):
def _parse_use(self):
    """Build an ``exp.Use`` node whose ``this`` is the result of ``self._parse_id_var()``."""
    target = self._parse_id_var()
    return self.expression(exp.Use, this=target)

def _parse_transaction(self):
    """Parse the remainder of a statement that starts a transaction.

    Accepts, in order: an optional word from ``self.TRANSACTION_KIND``, an
    optional ``TRANSACTION`` or ``WORK`` keyword, and a comma-separated list
    of modes. Each mode is a run of consecutive VAR tokens joined by single
    spaces.

    Returns:
        An ``exp.Transaction`` whose ``this`` is the matched kind text (or
        ``None``) and whose ``modes`` is the list of mode strings, possibly
        empty.
    """
    kind = self._prev.text if self._match_texts(self.TRANSACTION_KIND) else None

    self._match_texts({"TRANSACTION", "WORK"})

    modes = []
    more = True
    while more:
        words = []
        while self._match(TokenType.VAR):
            words.append(self._prev.text)

        if words:
            modes.append(" ".join(words))

        # Keep going only while another comma-separated mode follows.
        more = self._match(TokenType.COMMA)

    return self.expression(exp.Transaction, this=kind, modes=modes)

def _parse_commit_or_rollback(self):
    """Parse the remainder of a COMMIT or ROLLBACK statement.

    Which statement is being parsed is read from the type of the previously
    consumed token. An optional ``TRANSACTION`` or ``WORK`` keyword is
    skipped, then an optional ``TO [SAVEPOINT] <name>`` clause is parsed.

    Returns:
        An ``exp.Rollback`` carrying the savepoint (or ``None``) for a
        ROLLBACK token; otherwise an ``exp.Commit`` with no arguments, so
        any parsed savepoint is not kept.
    """
    # Must be read before any matching below moves ``self._prev``.
    rolling_back = self._prev.token_type == TokenType.ROLLBACK

    self._match_texts({"TRANSACTION", "WORK"})

    savepoint = None
    if self._match_text_seq("TO"):
        self._match_text_seq("SAVEPOINT")
        savepoint = self._parse_id_var()

    if not rolling_back:
        return self.expression(exp.Commit)
    return self.expression(exp.Rollback, savepoint=savepoint)

def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser:
Expand Down Expand Up @@ -2675,7 +2714,13 @@ def _match_r_paren(self, expression=None):
if expression and self._prev_comment:
expression.comment = self._prev_comment

def _match_text(self, *texts):
def _match_texts(self, texts):
    """Consume the current token if its upper-cased text is one of ``texts``.

    Args:
        texts: container of upper-case strings to test membership against.

    Returns:
        ``True`` after advancing past a matching token; ``False`` (without
        advancing) when there is no current token or its text is not in
        ``texts``.
    """
    token = self._curr
    if not token or token.text.upper() not in texts:
        return False

    self._advance()
    return True

def _match_text_seq(self, *texts):
index = self._index
for text in texts:
if self._curr and self._curr.text.upper() == text:
Expand Down
5 changes: 1 addition & 4 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,20 +670,17 @@ class Tokenizer(metaclass=_Tokenizer):
}

COMMANDS = {
TokenType.ALTER,
TokenType.ADD_FILE,
TokenType.ALTER,
TokenType.ANALYZE,
TokenType.BEGIN,
TokenType.CALL,
TokenType.COMMENT_ON,
TokenType.COMMIT,
TokenType.EXPLAIN,
TokenType.OPTIMIZE,
TokenType.SET,
TokenType.SHOW,
TokenType.TRUNCATE,
TokenType.VACUUM,
TokenType.ROLLBACK,
}

# handle numeric literals like in hive (3L = BIGINT)
Expand Down
4 changes: 4 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,10 @@ def test_bigquery(self):
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
},
)
self.validate_identity("BEGIN")
self.validate_identity("BEGIN TRANSACTION")
self.validate_identity("COMMIT TRANSACTION")
self.validate_identity("ROLLBACK TRANSACTION")

def test_user_defined_functions(self):
self.validate_identity(
Expand Down
Loading

0 comments on commit 148282e

Please sign in to comment.