Skip to content

Commit

Permalink
Return just row count from 'Transaction.execute_update'.
Browse files Browse the repository at this point in the history
Also, drop the 'partition' argument to it:  not appropriate to the
usecase.
  • Loading branch information
tseaver committed Sep 21, 2018
1 parent a71e8fb commit 08a3399
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 25 deletions.
9 changes: 2 additions & 7 deletions spanner/google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def commit(self):
return self.committed

def execute_update(self, dml, params=None, param_types=None,
query_mode=None, partition=None):
query_mode=None):
"""Perform an ``ExecuteSql`` API request with DML.
:type dml: str
Expand All @@ -149,10 +149,6 @@ def execute_update(self, dml, params=None, param_types=None,
:param query_mode: Mode governing return of results / query plan. See
https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1
:type partition: bytes
:param partition: (Optional) one of the partition tokens returned
from :meth:`partition_query`.
:rtype:
:class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.ResultSetStats`
:returns:
Expand Down Expand Up @@ -180,13 +176,12 @@ def execute_update(self, dml, params=None, param_types=None,
params=params_pb,
param_types=param_types,
query_mode=query_mode,
partition_token=partition,
seqno=self._execute_sql_count,
metadata=metadata,
)

self._execute_sql_count += 1
return response.stats
return response.stats.row_count_exact

def __enter__(self):
"""Begin ``with`` block."""
Expand Down
8 changes: 4 additions & 4 deletions spanner/tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,8 @@ def test_transaction_execute_update_read_commit(self):
self.assertEqual(rows, [])

for insert_statement in self._generate_insert_statements():
result = transaction.execute_update(insert_statement)
self.assertEqual(result.row_count_exact, 1)
row_count = transaction.execute_update(insert_statement)
self.assertEqual(row_count, 1)

# Rows inserted via DML *can* be read before commit.
during_rows = list(
Expand Down Expand Up @@ -722,8 +722,8 @@ def test_transaction_execute_update_then_insert_commit(self):
rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL))
self.assertEqual(rows, [])

result = transaction.execute_update(insert_statement)
self.assertEqual(result.row_count_exact, 1)
row_count = transaction.execute_update(insert_statement)
self.assertEqual(row_count, 1)

transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA[1:])

Expand Down
23 changes: 9 additions & 14 deletions spanner/tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_execute_update_w_params_wo_param_types(self):
with self.assertRaises(ValueError):
transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS)

def _execute_update_helper(self, partition=None, count=0):
def _execute_update_helper(self, count=0):
from google.protobuf.struct_pb2 import Struct
from google.cloud.spanner_v1.proto.result_set_pb2 import (
ResultSet, ResultSetStats)
Expand All @@ -334,10 +334,7 @@ def _execute_update_helper(self, partition=None, count=0):
from google.cloud.spanner_v1._helpers import _make_value_pb

MODE = 2 # PROFILE
stats_pb = ResultSetStats(
query_stats=Struct(fields={
'rows_affected': _make_value_pb(1),
}))
stats_pb = ResultSetStats(row_count_exact=1)
database = _Database()
api = database.spanner_api = self._make_spanner_api()
api.execute_sql.return_value = ResultSet(stats=stats_pb)
Expand All @@ -346,11 +343,10 @@ def _execute_update_helper(self, partition=None, count=0):
transaction._transaction_id = self.TRANSACTION_ID
transaction._execute_sql_count = count

result = transaction.execute_update(
DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES,
query_mode=MODE, partition=partition)
row_count = transaction.execute_update(
DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, query_mode=MODE)

self.assertEqual(result, stats_pb)
self.assertEqual(row_count, 1)

expected_transaction = TransactionSelector(id=self.TRANSACTION_ID)
expected_params = Struct(fields={
Expand All @@ -363,18 +359,17 @@ def _execute_update_helper(self, partition=None, count=0):
params=expected_params,
param_types=PARAM_TYPES,
query_mode=MODE,
partition_token=partition,
seqno=count,
metadata=[('google-cloud-resource-prefix', database.name)],
)

self.assertEqual(transaction._execute_sql_count, count + 1)

def test_execute_update_w_count_wo_partition(self):
self._execute_update_helper(count=1)
def test_execute_update_new_transaction(self):
self._execute_update_helper()

def test_execute_update_wo_count_w_partition(self):
self._execute_update_helper(partition=b'FACEDACE')
def test_execute_update_w_count(self):
self._execute_update_helper(count=1)

def test_context_mgr_success(self):
import datetime
Expand Down

0 comments on commit 08a3399

Please sign in to comment.