diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 65ee553806ff..858887e78679 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -641,7 +641,42 @@ def _generate_insert_statements(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) - def test_transaction_execute_sql_w_dml_read_commit(self): + def test_transaction_execute_sql_w_dml_read_rollback(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + transaction = session.transaction() + transaction.begin() + + rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + + for insert_statement in self._generate_insert_statements(): + result = transaction.execute_sql(insert_statement) + list(result) # iterate to get stats + self.assertEqual(result.stats.row_count_exact, 1) + + # Rows inserted via DML *can* be read before commit. + during_rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(during_rows) + + transaction.rollback() + + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(rows, []) + + @RetryErrors(exception=exceptions.ServerError) + @RetryErrors(exception=exceptions.Conflict) + def test_transaction_execute_update_read_commit(self): retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() @@ -657,9 +692,8 @@ def test_transaction_execute_sql_w_dml_read_commit(self): self.assertEqual(rows, []) for insert_statement in self._generate_insert_statements(): - result = transaction.execute_sql(insert_statement) - list(result) # iterate to get stats - self.assertEqual(result.stats.row_count_exact, 1) + result = transaction.execute_update(insert_statement) + self.assertEqual(result.row_count_exact, 1) # Rows inserted via DML *can* be read before commit. during_rows = list( @@ -671,7 +705,7 @@ def test_transaction_execute_sql_w_dml_read_commit(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) - def test_transaction_execute_update_read_commit(self): + def test_transaction_execute_update_then_insert_commit(self): retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() @@ -682,18 +716,16 @@ def test_transaction_execute_update_read_commit(self): with session.batch() as batch: batch.delete(self.TABLE, self.ALL) + insert_statement = list(self._generate_insert_statements())[0] + with session.transaction() as transaction: rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) self.assertEqual(rows, []) - for insert_statement in self._generate_insert_statements(): - result = transaction.execute_update(insert_statement) - self.assertEqual(result.row_count_exact, 1) + result = transaction.execute_update(insert_statement) + self.assertEqual(result.row_count_exact, 1) - # Rows inserted via DML *can* be read before commit. - during_rows = list( - transaction.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(during_rows) + transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA[1:]) rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_rows_data(rows)