From 148282e710fd79512bb7d32e6e519d631df8115d Mon Sep 17 00:00:00 2001 From: Jo <46752250+GeorgeSittas@users.noreply.github.com> Date: Fri, 11 Nov 2022 01:27:38 +0200 Subject: [PATCH] Add transaction/commit expressions (#684) * Add transaction expression * pr feedback, fix bigquery issue * Add Commit expression * Bring rollback token back * comment/test for bigquery BEGIN block statement * Simplify parse transaction * add expression node for rollback, cleanup * fixup --- README.md | 2 +- sqlglot/dialects/bigquery.py | 15 +++++++ sqlglot/dialects/mysql.py | 29 +++++++------- sqlglot/dialects/oracle.py | 1 + sqlglot/dialects/presto.py | 6 +++ sqlglot/dialects/sqlite.py | 5 +++ sqlglot/expressions.py | 21 +++++++--- sqlglot/generator.py | 11 ++++++ sqlglot/parser.py | 69 +++++++++++++++++++++++++++------ sqlglot/tokens.py | 5 +-- tests/dialects/test_bigquery.py | 4 ++ tests/dialects/test_dialect.py | 33 ++++++++++++++++ tests/dialects/test_presto.py | 2 + tests/fixtures/identity.sql | 3 +- tests/test_expressions.py | 3 ++ tests/test_parser.py | 29 ++++++++++++++ 16 files changed, 200 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index b00b803d22..9f3dedb182 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive" "SELECT DATE_FORMAT(x, 'yy-M-ss')" ``` -As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks as identifiers and `FLOAT` instead of `REAL`: +As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks for identifiers and `FLOAT` instead of `REAL`: ```python import sqlglot diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 5bbff9dea4..65e66adfdd 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -151,6 +151,12 @@ class Parser(parser.Parser): TokenType.TABLE, } + # BEGIN signifies the start of a block statement, so it's different from BEGIN TRANSACTION + def _parse_transaction(self): + if self._match_text_seq("TRANSACTION"): + return self.expression(exp.Transaction) + return self.expression(exp.Command, this=self._prev.text) + class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -204,6 +210,15 @@ class Generator(generator.Generator): EXPLICIT_UNION = True + def transaction_sql(self, *_): + return "BEGIN TRANSACTION" + + def commit_sql(self, *_): + return "COMMIT TRANSACTION" + + def rollback_sql(self, *_): + return "ROLLBACK TRANSACTION" + def in_unnest_op(self, unnest): return self.sql(unnest) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index e742640624..30c21b75ce 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -120,6 +120,7 @@ class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -281,36 +282,36 @@ class Parser(parser.Parser): def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): - self._match_text(target) + self._match_text_seq(target) target_id = self._parse_id_var() else: target_id = None - log = self._parse_string() if self._match_text("IN") else None + log = self._parse_string() if self._match_text_seq("IN") else None if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: - position = self._parse_number() if self._match_text("FROM") else None + position = self._parse_number() if self._match_text_seq("FROM") else None db = None else: position = None - db = self._parse_id_var() if self._match_text("FROM") else None + db = self._parse_id_var() if self._match_text_seq("FROM") else None - channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None + channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None - like = self._parse_string() if self._match_text("LIKE") else None + like = self._parse_string() if self._match_text_seq("LIKE") else None where = self._parse_where() if this == "PROFILE": types = self._parse_csv(self._parse_show_profile_type) - query = self._parse_number() if self._match_text("FOR", "QUERY") else None - offset = self._parse_number() if self._match_text("OFFSET") else None - limit = self._parse_number() if self._match_text("LIMIT") else None + query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None + offset = self._parse_number() if self._match_text_seq("OFFSET") else None + limit = self._parse_number() if self._match_text_seq("LIMIT") else None else: types, query = None, None offset, limit = self._parse_oldstyle_limit() - mutex = True if self._match_text("MUTEX") else None - mutex = False if self._match_text("STATUS") else mutex + mutex = True if self._match_text_seq("MUTEX") else None + mutex = False if self._match_text_seq("STATUS") else mutex return self.expression( exp.Show, @@ -333,14 +334,14 @@ def _parse_show_mysql(self, this, target=False, full=None, global_=None): def _parse_show_profile_type(self): for type_ in self.PROFILE_TYPES: - if self._match_text(*type_.split(" ")): + if self._match_text_seq(*type_.split(" ")): return exp.Var(this=type_) return None def _parse_oldstyle_limit(self): limit = None offset = None - if self._match_text("LIMIT"): + if self._match_text_seq("LIMIT"): parts = self._parse_csv(self._parse_number) if len(parts) == 1: limit = parts[0] @@ -381,7 +382,7 @@ def _parse_set_item_charset(self, kind): def _parse_set_item_names(self): charset = self._parse_string() or self._parse_id_var() - if self._match_text("COLLATE"): + if self._match_text_seq("COLLATE"): collate = self._parse_string() or self._parse_id_var() else: collate = None diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 3bc1109749..870d2b9594 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -91,6 +91,7 @@ def table_sql(self, expression): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, "NVARCHAR2": TokenType.NVARCHAR, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 11ea77859f..3b17f1d9c1 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -115,6 +115,7 @@ class Presto(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "ROW": TokenType.STRUCT, } @@ -216,3 +217,8 @@ class Generator(generator.Generator): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", } + + def transaction_sql(self, expression): + modes = expression.args.get("modes") + modes = f" {', '.join(modes)}" if modes else "" + return f"START TRANSACTION{modes}" diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 8c9fb76c69..87b98a592b 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -63,3 +63,8 @@ class Generator(generator.Generator): exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, } + + def transaction_sql(self, expression): + this = expression.this + this = f" {this}" if this else "" + return f"BEGIN{this} TRANSACTION" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 57a2c88fe9..1b231539d0 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2054,16 +2054,25 @@ class Exists(SubqueryPredicate): pass -# Commands to interact with the databases or engines -# These expressions don't truly parse the expression and consume -# whatever exists as a string until the end or a semicolon +# Commands to interact with the databases or engines. For most of the command +# expressions we parse whatever comes after the command's name as a string. class Command(Expression): arg_types = {"this": True, "expression": False} -# Binary Expressions -# (ADD a b) -# (FROM table selects) +class Transaction(Command): + arg_types = {"this": False, "modes": False} + + +class Commit(Command): + arg_types = {} # type: ignore + + +class Rollback(Command): + arg_types = {"savepoint": False} + + +# Binary expressions like (ADD a b) class Binary(Expression): arg_types = {"this": True, "expression": True} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 11d90731b3..11ae3a1bae 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1176,6 +1176,17 @@ def currentdate_sql(self, expression): def command_sql(self, expression): return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" + def transaction_sql(self, *_): + return "BEGIN" + + def commit_sql(self, *_): + return "COMMIT" + + def rollback_sql(self, expression): + savepoint = expression.args.get("savepoint") + savepoint = f" TO {savepoint}" if savepoint else "" + return f"ROLLBACK{savepoint}" + def distinct_sql(self, expression): this = self.expressions(expression, flat=True) this = f" {this}" if this else "" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index bbea0e573c..69894464c0 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -392,6 +392,9 @@ class Parser(metaclass=_Parser): TokenType.CACHE: lambda self: self._parse_cache(), TokenType.UNCACHE: lambda self: self._parse_uncache(), TokenType.USE: lambda self: self._parse_use(), + TokenType.BEGIN: lambda self: self._parse_transaction(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), + TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), } PRIMARY_PARSERS = { @@ -521,6 +524,8 @@ class Parser(metaclass=_Parser): TokenType.SCHEMA, } + TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + STRICT_CAST = True __slots__ = ( @@ -930,7 +935,7 @@ def _parse_describe(self): def _parse_insert(self): overwrite = self._match(TokenType.OVERWRITE) local = self._match(TokenType.LOCAL) - if self._match_text("DIRECTORY"): + if self._match_text_seq("DIRECTORY"): this = self.expression( exp.Directory, this=self._parse_var_or_string(), @@ -954,27 +959,27 @@ def _parse_row_format(self): if not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None - self._match_text("DELIMITED") + self._match_text_seq("DELIMITED") kwargs = {} - if self._match_text("FIELDS", "TERMINATED", "BY"): + if self._match_text_seq("FIELDS", "TERMINATED", "BY"): kwargs["fields"] = self._parse_string() - if self._match_text("ESCAPED", "BY"): + if self._match_text_seq("ESCAPED", "BY"): kwargs["escaped"] = self._parse_string() - if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"): + if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"): kwargs["collection_items"] = self._parse_string() - if self._match_text("MAP", "KEYS", "TERMINATED", "BY"): + if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"): kwargs["map_keys"] = self._parse_string() - if self._match_text("LINES", "TERMINATED", "BY"): + if self._match_text_seq("LINES", "TERMINATED", "BY"): kwargs["lines"] = self._parse_string() - if self._match_text("NULL", "DEFINED", "AS"): + if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = self._parse_string() return self.expression(exp.RowFormat, **kwargs) def _parse_load_data(self): local = self._match(TokenType.LOCAL) - self._match_text("INPATH") + self._match_text_seq("INPATH") inpath = self._parse_string() overwrite = self._match(TokenType.OVERWRITE) self._match_pair(TokenType.INTO, TokenType.TABLE) @@ -986,8 +991,8 @@ def _parse_load_data(self): overwrite=overwrite, inpath=inpath, partition=self._parse_partition(), - input_format=self._match_text("INPUTFORMAT") and self._parse_string(), - serde=self._match_text("SERDE") and self._parse_string(), + input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(), + serde=self._match_text_seq("SERDE") and self._parse_string(), ) def _parse_delete(self): @@ -2594,6 +2599,40 @@ def _parse_select_or_expression(self): def _parse_use(self): return self.expression(exp.Use, this=self._parse_id_var()) + def _parse_transaction(self): + this = None + if self._match_texts(self.TRANSACTION_KIND): + this = self._prev.text + + self._match_texts({"TRANSACTION", "WORK"}) + + modes = [] + while True: + mode = [] + while self._match(TokenType.VAR): + mode.append(self._prev.text) + + if mode: + modes.append(" ".join(mode)) + if not self._match(TokenType.COMMA): + break + + return self.expression(exp.Transaction, this=this, modes=modes) + + def _parse_commit_or_rollback(self): + savepoint = None + is_rollback = self._prev.token_type == TokenType.ROLLBACK + + self._match_texts({"TRANSACTION", "WORK"}) + + if self._match_text_seq("TO"): + self._match_text_seq("SAVEPOINT") + savepoint = self._parse_id_var() + + if is_rollback: + return self.expression(exp.Rollback, savepoint=savepoint) + return self.expression(exp.Commit) + def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) if parser: @@ -2675,7 +2714,13 @@ def _match_r_paren(self, expression=None): if expression and self._prev_comment: expression.comment = self._prev_comment - def _match_text(self, *texts): + def _match_texts(self, texts): + if self._curr and self._curr.text.upper() in texts: + self._advance() + return True + return False + + def _match_text_seq(self, *texts): index = self._index for text in texts: if self._curr and self._curr.text.upper() == text: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 95d84d6eb9..5d28753ad6 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -670,20 +670,17 @@ class Tokenizer(metaclass=_Tokenizer): } COMMANDS = { - TokenType.ALTER, TokenType.ADD_FILE, + TokenType.ALTER, TokenType.ANALYZE, - TokenType.BEGIN, TokenType.CALL, TokenType.COMMENT_ON, - TokenType.COMMIT, TokenType.EXPLAIN, TokenType.OPTIMIZE, TokenType.SET, TokenType.SHOW, TokenType.TRUNCATE, TokenType.VACUUM, - TokenType.ROLLBACK, } # handle numeric literals like in hive (3L = BIGINT) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index a0ebc45824..790e5ef34e 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -286,6 +286,10 @@ def test_bigquery(self): "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", }, ) + self.validate_identity("BEGIN") + self.validate_identity("BEGIN TRANSACTION") + self.validate_identity("COMMIT TRANSACTION") + self.validate_identity("ROLLBACK TRANSACTION") def test_user_defined_functions(self): self.validate_identity( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 1913f5307c..e24b54eaee 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1230,3 +1230,36 @@ def test_hash_comments(self): }, pretty=True, ) + + def test_transactions(self): + self.validate_all( + "BEGIN TRANSACTION", + write={ + "bigquery": "BEGIN TRANSACTION", + "mysql": "BEGIN", + "postgres": "BEGIN", + "presto": "START TRANSACTION", + "trino": "START TRANSACTION", + "redshift": "BEGIN", + "snowflake": "BEGIN", + "sqlite": "BEGIN TRANSACTION", + }, + ) + self.validate_all( + "BEGIN", + read={ + "presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", + "trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", + }, + ) + self.validate_all( + "BEGIN", + read={ + "presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ", + "trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ", + }, + ) + self.validate_all( + "BEGIN IMMEDIATE TRANSACTION", + write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"}, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 098ad2b777..3f9437f8d6 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -427,3 +427,5 @@ def test_presto(self): "spark": UnsupportedError, }, ) + self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") + self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 836ab28c46..652d9a275b 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -128,7 +128,6 @@ ADD FILE s3://file ADD FILES s3://file, s3://a ADD ARCHIVE s3://file ADD ARCHIVES s3://file, s3://a -BEGIN IMMEDIATE TRANSACTION COMMIT USE db NOT 1 @@ -524,7 +523,9 @@ DROP VIEW IF EXISTS a DROP VIEW IF EXISTS a.b SHOW TABLES USE db +BEGIN ROLLBACK +ROLLBACK TO b EXPLAIN SELECT * FROM x INSERT INTO x SELECT * FROM y INSERT INTO x (SELECT * FROM y) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 63371d8fba..493fcc5d65 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -441,6 +441,9 @@ def test_functions(self): self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance) self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) + self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction) + self.assertIsInstance(parse_one("COMMIT"), exp.Commit) + self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback) def test_column(self): dot = parse_one("a.b.c") diff --git a/tests/test_parser.py b/tests/test_parser.py index 04c20b1bbe..aeb518e49c 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -47,6 +47,35 @@ def test_command(self): self.assertEqual(expressions[1].sql(), "ADD JAR s3://a") self.assertEqual(expressions[2].sql(), "SELECT 1") + def test_transactions(self): + expression = parse_one("BEGIN TRANSACTION") + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("START TRANSACTION", read="mysql") + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("BEGIN DEFERRED TRANSACTION") + self.assertEqual(expression.this, "DEFERRED") + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one( + "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto" + ) + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"][0], "READ WRITE") + self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE") + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("BEGIN", read="bigquery") + self.assertNotIsInstance(expression, exp.Transaction) + self.assertIsNone(expression.expression) + self.assertEqual(expression.sql(), "BEGIN") + def test_identify(self): expression = parse_one( """