diff --git a/spanner/google/cloud/spanner_v1/database.py b/spanner/google/cloud/spanner_v1/database.py index d3494eb63902..93960947de20 100644 --- a/spanner/google/cloud/spanner_v1/database.py +++ b/spanner/google/cloud/spanner_v1/database.py @@ -14,17 +14,20 @@ """User friendly container for Cloud Spanner Database.""" +import copy +import functools import re import threading -import copy from google.api_core.gapic_v1 import client_info import google.auth.credentials +from google.protobuf.struct_pb2 import Struct from google.cloud.exceptions import NotFound import six # pylint: disable=ungrouped-imports from google.cloud.spanner_v1 import __version__ +from google.cloud.spanner_v1._helpers import _make_value_pb from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient @@ -32,7 +35,11 @@ from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.pool import SessionCheckout from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot +from google.cloud.spanner_v1.streamed import StreamedResultSet +from google.cloud.spanner_v1.proto.transaction_pb2 import ( + TransactionSelector, TransactionOptions) # pylint: enable=ungrouped-imports @@ -272,6 +279,70 @@ def drop(self): metadata = _metadata_with_prefix(self.name) api.drop_database(self.name, metadata=metadata) + def execute_partitioned_dml( + self, dml, params=None, param_types=None, query_mode=None): + """Execute a partitionable DML statement. + + :type dml: str + :param dml: SQL DML statement + + :type params: dict, {str -> column value} + :param params: values for parameter replacement. Keys must match + the names used in ``dml``. + + :type param_types: dict[str -> Union[dict, .types.Type]] + :param param_types: + (Optional) maps explicit types for one or more param values; + required if parameters are passed. + + :type query_mode: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryMode` + :param query_mode: Mode governing return of results / query plan. See + https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + + :rtype: int + :returns: Count of rows affected by the DML statement. + """ + if params is not None: + if param_types is None: + raise ValueError( + "Specify 'param_types' when passing 'params'.") + params_pb = Struct(fields={ + key: _make_value_pb(value) for key, value in params.items()}) + else: + params_pb = None + + api = self.spanner_api + + txn_options = TransactionOptions( + partitioned_dml=TransactionOptions.PartitionedDml()) + + metadata = _metadata_with_prefix(self.name) + + with SessionCheckout(self._pool) as session: + + txn = api.begin_transaction( + session.name, txn_options, metadata=metadata) + + txn_selector = TransactionSelector(id=txn.id) + + restart = functools.partial( + api.execute_streaming_sql, + session.name, + dml, + transaction=txn_selector, + params=params_pb, + param_types=param_types, + query_mode=query_mode, + metadata=metadata) + + iterator = _restart_on_unavailable(restart) + + result_set = StreamedResultSet(iterator) + list(result_set) # consume all partials + + return result_set.stats.row_count_lower_bound + def session(self, labels=None): """Factory to create a session for this database. diff --git a/spanner/google/cloud/spanner_v1/transaction.py b/spanner/google/cloud/spanner_v1/transaction.py index 18b87f5cc383..cc2f06cee54d 100644 --- a/spanner/google/cloud/spanner_v1/transaction.py +++ b/spanner/google/cloud/spanner_v1/transaction.py @@ -149,11 +149,8 @@ def execute_update(self, dml, params=None, param_types=None, :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :rtype: - :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.ResultSetStats` - :returns: - stats object, including count of rows affected by the DML - statement. + :rtype: int + :returns: Count of rows affected by the DML statement. """ if params is not None: if param_types is None: diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index b2a99cf45c31..228cd7849fa0 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -730,6 +730,62 @@ def test_transaction_execute_update_then_insert_commit(self): rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_rows_data(rows) + def test_execute_partitioned_dml(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + delete_statement = 'DELETE FROM {} WHERE true'.format(self.TABLE) + + def _setup_table(txn): + txn.execute_update(delete_statement) + for insert_statement in self._generate_insert_statements(): + txn.execute_update(insert_statement) + + committed = self._db.run_in_transaction(_setup_table) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + before_pdml = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + + self._check_rows_data(before_pdml) + + nonesuch = 'nonesuch@example.com' + target = 'phred@example.com' + update_statement = ( + 'UPDATE {table} SET {table}.email = @email ' + 'WHERE {table}.email = @target').format( + table=self.TABLE) + + row_count = self._db.execute_partitioned_dml( + update_statement, + params={ + 'email': nonesuch, + 'target': target, + }, + param_types={ + 'email': Type(code=STRING), + 'target': Type(code=STRING), + }, + ) + self.assertEqual(row_count, 1) + + row = self.ROW_DATA[0] + updated = [row[:3] + (nonesuch,)] + list(self.ROW_DATA[1:]) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + after_update = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(after_update, updated) + + row_count = self._db.execute_partitioned_dml(delete_statement) + self.assertEqual(row_count, len(self.ROW_DATA)) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + after_delete = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + + self._check_rows_data(after_delete, []) + def _transaction_concurrency_helper(self, unit_of_work, pkey): INITIAL_VALUE = 123 NUM_THREADS = 3 # conforms to equivalent Java systest. diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index 34b30deb2022..c17251647511 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -18,6 +18,19 @@ import mock +DML_WO_PARAM = """ +DELETE FROM citizens +""" + +DML_W_PARAM = """ +INSERT INTO citizens(first_name, last_name, age) +VALUES ("Phred", "Phlyntstone", @age) +""" +PARAMS = {'age': 30} +PARAM_TYPES = {'age': 'INT64'} +MODE = 2 # PROFILE + + def _make_credentials(): # pragma: NO COVER import google.auth.credentials @@ -39,7 +52,7 @@ class _BaseTest(unittest.TestCase): DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID SESSION_ID = 'session_id' SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID - TRANSACTION_ID = 'transaction_id' + TRANSACTION_ID = b'transaction_id' def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -65,6 +78,20 @@ def _get_target_class(self): return Database + @staticmethod + def _make_database_admin_api(): + from google.cloud.spanner_v1.client import DatabaseAdminClient + + return mock.create_autospec(DatabaseAdminClient, instance=True) + + @staticmethod + def _make_spanner_api(): + import google.cloud.spanner_v1.gapic.spanner_client + + return mock.create_autospec( + google.cloud.spanner_v1.gapic.spanner_client.SpannerClient, + instance=True) + def test_ctor_defaults(self): from google.cloud.spanner_v1.pool import BurstyPool @@ -296,10 +323,12 @@ def test___ne__(self): def test_create_grpc_error(self): from google.api_core.exceptions import GoogleAPICallError + from google.api_core.exceptions import Unknown client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = Unknown('testing') + instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -307,22 +336,20 @@ def test_create_grpc_error(self): with self.assertRaises(GoogleAPICallError): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE %s' % self.DATABASE_ID) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_already_exists(self): from google.cloud.exceptions import Conflict DATABASE_ID_HYPHEN = 'database-id' client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _create_database_conflict=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = Conflict('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(DATABASE_ID_HYPHEN, instance, pool=pool) @@ -330,45 +357,40 @@ def test_create_already_exists(self): with self.assertRaises(Conflict): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE `{}`'.format(DATABASE_ID_HYPHEN), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_instance_not_found(self): from google.cloud.exceptions import NotFound - DATABASE_ID_HYPHEN = 'database-id' client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() - database = self._make_one(DATABASE_ID_HYPHEN, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) with self.assertRaises(NotFound): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_success(self): from tests._fixtures import DDL_STATEMENTS - op_future = _FauxOperationFuture() + op_future = object() client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _create_database_response=op_future) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one( @@ -379,21 +401,19 @@ def test_create_success(self): self.assertIs(future, op_future) - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE %s' % self.DATABASE_ID) - self.assertEqual(extra_statements, DDL_STATEMENTS) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=DDL_STATEMENTS, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -401,20 +421,27 @@ def test_exists_grpc_error(self): with self.assertRaises(Unknown): database.exists() + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_exists_not_found(self): + from google.cloud.exceptions import NotFound + client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) self.assertFalse(database.exists()) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_success(self): from google.cloud.spanner_admin_database_v1.proto import ( @@ -424,25 +451,25 @@ def test_exists_success(self): client = _Client() ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( statements=DDL_STATEMENTS) - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _get_database_ddl_response=ddl_pb) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.return_value = ddl_pb instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) self.assertTrue(database.exists()) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_reload_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -450,12 +477,17 @@ def test_reload_grpc_error(self): with self.assertRaises(Unknown): database.reload() + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_reload_not_found(self): from google.cloud.exceptions import NotFound client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -463,10 +495,10 @@ def test_reload_not_found(self): with self.assertRaises(NotFound): database.reload() - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_reload_success(self): from google.cloud.spanner_admin_database_v1.proto import ( @@ -476,8 +508,8 @@ def test_reload_success(self): client = _Client() ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( statements=DDL_STATEMENTS) - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _get_database_ddl_response=ddl_pb) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.return_value = ddl_pb instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -486,18 +518,18 @@ def test_reload_success(self): self.assertEqual(database._ddl_statements, tuple(DDL_STATEMENTS)) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_update_ddl_grpc_error(self): from google.api_core.exceptions import Unknown from tests._fixtures import DDL_STATEMENTS client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -505,13 +537,20 @@ def test_update_ddl_grpc_error(self): with self.assertRaises(Unknown): database.update_ddl(DDL_STATEMENTS) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_update_ddl_not_found(self): from google.cloud.exceptions import NotFound from tests._fixtures import DDL_STATEMENTS client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -519,20 +558,20 @@ def test_update_ddl_not_found(self): with self.assertRaises(NotFound): database.update_ddl(DDL_STATEMENTS) - name, statements, op_id, metadata = api._updated_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual(statements, DDL_STATEMENTS) - self.assertEqual(op_id, '') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_update_ddl(self): from tests._fixtures import DDL_STATEMENTS - op_future = _FauxOperationFuture() + op_future = object() client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _update_database_ddl_response=op_future) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -541,19 +580,19 @@ def test_update_ddl(self): self.assertIs(future, op_future) - name, statements, op_id, metadata = api._updated_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual(statements, DDL_STATEMENTS) - self.assertEqual(op_id, '') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_drop_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -561,12 +600,17 @@ def test_drop_grpc_error(self): with self.assertRaises(Unknown): database.drop() + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_drop_not_found(self): from google.cloud.exceptions import NotFound client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -574,27 +618,99 @@ def test_drop_not_found(self): with self.assertRaises(NotFound): database.drop() - name, metadata = api._dropped_database - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_drop_success(self): from google.protobuf.empty_pb2 import Empty client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _drop_database_response=Empty()) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.return_value = Empty() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) database.drop() - name, metadata = api._dropped_database - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def _execute_partitioned_dml_helper( + self, dml, params=None, param_types=None): + from google.protobuf.struct_pb2 import Struct + from google.cloud.spanner_v1.proto.result_set_pb2 import ( + PartialResultSet, ResultSetStats) + from google.cloud.spanner_v1.proto.transaction_pb2 import ( + Transaction as TransactionPB, + TransactionSelector, TransactionOptions) + from google.cloud.spanner_v1._helpers import _make_value_pb + + transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + + stats_pb = ResultSetStats(row_count_lower_bound=2) + result_sets = [ + PartialResultSet(stats=stats_pb), + ] + iterator = _MockIterator(*result_sets) + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + api = database._spanner_api = self._make_spanner_api() + api.begin_transaction.return_value = transaction_pb + api.execute_streaming_sql.return_value = iterator + + row_count = database.execute_partitioned_dml( + dml, params, param_types, query_mode=MODE) + + self.assertEqual(row_count, 2) + + txn_options = TransactionOptions( + partitioned_dml=TransactionOptions.PartitionedDml()) + + api.begin_transaction.assert_called_once_with( + session.name, + txn_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + if params: + expected_params = Struct(fields={ + key: _make_value_pb(value) for (key, value) in params.items()}) + else: + expected_params = None + + expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + + api.execute_streaming_sql.assert_called_once_with( + self.SESSION_NAME, + dml, + transaction=expected_transaction, + params=expected_params, + param_types=param_types, + query_mode=MODE, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def test_execute_partitioned_dml_wo_params(self): + self._execute_partitioned_dml_helper(dml=DML_WO_PARAM) + + def test_execute_partitioned_dml_w_params_wo_param_types(self): + with self.assertRaises(ValueError): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, params=PARAMS) + + def test_execute_partitioned_dml_w_params_and_param_types(self): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, params=PARAMS, param_types=PARAM_TYPES) def test_session_factory_defaults(self): from google.cloud.spanner_v1.session import Session @@ -787,6 +903,12 @@ def _get_target_class(self): return BatchCheckout + @staticmethod + def _make_spanner_client(): + from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient + + return mock.create_autospec(SpannerClient) + def test_ctor(self): database = _Database(self.DATABASE_NAME) checkout = self._make_one(database) @@ -805,8 +927,8 @@ def test_context_mgr_success(self): now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) database = _Database(self.DATABASE_NAME) - api = database.spanner_api = _FauxSpannerClient() - api._commit_response = response + api = database.spanner_api = self._make_spanner_client() + api.commit.return_value = response pool = database._pool = _Pool() session = _Session(database) pool.put(session) @@ -819,14 +941,15 @@ def test_context_mgr_success(self): self.assertIs(pool._session, session) self.assertEqual(batch.committed, now) - (session_name, mutations, single_use_txn, - metadata) = api._committed - self.assertIs(session_name, self.SESSION_NAME) - self.assertEqual(mutations, []) - self.assertIsInstance(single_use_txn, TransactionOptions) - self.assertTrue(single_use_txn.HasField('read_write')) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + + expected_txn_options = TransactionOptions(read_write={}) + + api.commit.assert_called_once_with( + self.SESSION_NAME, + [], + single_use_transaction=expected_txn_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import Batch @@ -1433,80 +1556,19 @@ def run_in_transaction(self, func, *args, **kw): return self._committed -class _SessionPB(object): - name = TestDatabase.SESSION_NAME - - -class _FauxOperationFuture(object): - pass - - -class _FauxSpannerClient(object): - - _committed = None - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def commit(self, session, mutations, - transaction_id='', single_use_transaction=None, metadata=None): - assert transaction_id == '' - self._committed = ( - session, mutations, single_use_transaction, metadata) - return self._commit_response - - -class _FauxDatabaseAdminAPI(object): - - _create_database_conflict = False - _database_not_found = False - _rpc_error = False - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def create_database(self, parent, create_statement, extra_statements=None, - metadata=None): - from google.api_core.exceptions import AlreadyExists, NotFound, Unknown - - self._created_database = ( - parent, create_statement, extra_statements, metadata) - if self._rpc_error: - raise Unknown('error') - if self._create_database_conflict: - raise AlreadyExists('conflict') - if self._database_not_found: - raise NotFound('not found') - return self._create_database_response - - def get_database_ddl(self, database, metadata=None): - from google.api_core.exceptions import NotFound, Unknown - - self._got_database_ddl = database, metadata - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._get_database_ddl_response +class _MockIterator(object): - def drop_database(self, database, metadata=None): - from google.api_core.exceptions import NotFound, Unknown + def __init__(self, *values, **kw): + self._iter_values = iter(values) + self._fail_after = kw.pop('fail_after', False) - self._dropped_database = database, metadata - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._drop_database_response + def __iter__(self): + return self - def update_database_ddl(self, database, statements, operation_id, - metadata=None): - from google.api_core.exceptions import NotFound, Unknown + def __next__(self): + try: + return next(self._iter_values) + except StopIteration: + raise - self._updated_database_ddl = ( - database, statements, operation_id, metadata) - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._update_database_ddl_response + next = __next__