Add 'BatchTransaction' wrapper class (#438)
Encapsulates session ID / transaction ID, to be marshalled across
the wire to another process / host for performing partitioned
reads / queries.
tseaver committed Feb 26, 2018
1 parent 6b62951 commit 96df657
Showing 5 changed files with 1,064 additions and 30 deletions.
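For context, a hedged sketch of the workflow this commit enables (illustrative only: ``database``, ``my_table``, and the transport between processes are assumptions, not part of this commit):

# Coordinator process: open a batch transaction, generate partition
# batches, and marshal the session / transaction IDs for the workers.
from google.cloud.spanner_v1.keyset import KeySet

batch_txn = database.batch_transaction()
batches = list(batch_txn.generate_read_batches(
    table='my_table', columns=['id', 'value'], keyset=KeySet(all_=True)))
txn_state = batch_txn.to_dict()  # ship across the wire with the batches

# Worker process / host: reconstitute the transaction and process its
# share of the batches against the same database.
from google.cloud.spanner_v1.database import BatchTransaction

worker_txn = BatchTransaction.from_dict(database, txn_state)
for row in worker_txn.process(batches[0]):
    print(row)

# Coordinator, once every partition has been processed: clean up.
batch_txn.close()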
266 changes: 266 additions & 0 deletions spanner/google/cloud/spanner_v1/database.py
@@ -27,6 +27,7 @@
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
from google.cloud.spanner_v1.session import Session
@@ -308,6 +309,14 @@ def batch(self):
"""
return BatchCheckout(self)

def batch_transaction(self):
"""Return an object which wraps a batch read / query.
:rtype: :class:`~google.cloud.spanner_v1.database.BatchTransaction`
:returns: new wrapper
"""
return BatchTransaction(self)

def run_in_transaction(self, func, *args, **kw):
"""Perform a unit of work in a transaction, retrying on abort.
@@ -406,6 +415,263 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._database._pool.put(self._session)


class BatchTransaction(object):
"""Wrapper for generating and processing read / query batches.
:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: database to use
:type read_timestamp: :class:`datetime.datetime`
:param read_timestamp: Execute all reads at the given timestamp.
:type min_read_timestamp: :class:`datetime.datetime`
:param min_read_timestamp: Execute all reads at a
timestamp >= ``min_read_timestamp``.
:type max_staleness: :class:`datetime.timedelta`
:param max_staleness: Read data at a
timestamp >= NOW - ``max_staleness`` seconds.
:type exact_staleness: :class:`datetime.timedelta`
:param exact_staleness: Execute all reads at a timestamp that is
``exact_staleness`` old.
"""
def __init__(
self, database,
read_timestamp=None,
min_read_timestamp=None,
max_staleness=None,
exact_staleness=None):

self._database = database
self._session = None
self._snapshot = None
self._read_timestamp = read_timestamp
self._min_read_timestamp = min_read_timestamp
self._max_staleness = max_staleness
self._exact_staleness = exact_staleness

@classmethod
def from_dict(cls, database, mapping):
"""Reconstruct an instance from a mapping.
:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: database to use
:type mapping: mapping
:param mapping: serialized state of the instance
:rtype: :class:`BatchTransaction`
:returns: an instance managing the existing session / transaction.
"""
instance = cls(database)
session = instance._session = database.session()
session._session_id = mapping['session_id']
txn = session.transaction()
txn._transaction_id = mapping['transaction_id']
return instance

def to_dict(self):
"""Return state as a dictionary.
The result can be used to serialize the instance and reconstitute
it later using :meth:`from_dict`.
:rtype: dict
:returns: state of this instance.
"""
session = self._get_session()
return {
'session_id': session._session_id,
'transaction_id': session._transaction._transaction_id,
}
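Note that ``to_dict`` and ``from_dict`` are intended as inverses: a coordinator hands ``batch_txn.to_dict()`` to another process, which rebuilds the wrapper via ``BatchTransaction.from_dict(database, state)`` (names hypothetical).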

def _get_session(self):
"""Create session as needed.
.. note::
Caller is responsible for cleaning up the session after
all partitions have been processed.
"""
if self._session is None:
session = self._session = self._database.session()
session.create()
txn = session.transaction()
txn.begin()
return self._session

def _get_snapshot(self):
"""Create snapshot if needed."""
if self._snapshot is None:
self._snapshot = self._get_session().snapshot(
read_timestamp=self._read_timestamp,
min_read_timestamp=self._min_read_timestamp,
max_staleness=self._max_staleness,
exact_staleness=self._exact_staleness,
multi_use=True)
return self._snapshot

def generate_read_batches(
self, table, columns, keyset,
index='', partition_size_bytes=None, max_partitions=None):
"""Start a partitioned batch read operation.
Uses the ``PartitionRead`` API request to initiate the partitioned
read. Returns a list of batch information needed to perform the
actual reads.
:type table: str
:param table: name of the table from which to fetch data
:type columns: list of str
:param columns: names of columns to be retrieved
:type keyset: :class:`~google.cloud.spanner_v1.keyset.KeySet`
:param keyset: keys / ranges identifying rows to be retrieved
:type index: str
:param index: (Optional) name of index to use, rather than the
table's primary key
:type partition_size_bytes: int
:param partition_size_bytes:
(Optional) desired size for each partition generated. The service
uses this as a hint; the actual partition size may differ.
:type max_partitions: int
:param max_partitions:
(Optional) desired maximum number of partitions generated. The
service uses this as a hint; the actual number of partitions may
differ.
:rtype: iterable of dict
:returns:
mappings of information used to perform actual partitioned reads via
:meth:`process_read_batch`.
"""
partitions = self._get_snapshot().partition_read(
table=table, columns=columns, keyset=keyset, index=index,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions)

read_info = {
'table': table,
'columns': columns,
'keyset': keyset._to_dict(),
'index': index,
}
for partition in partitions:
yield {'partition': partition, 'read': read_info.copy()}

def process_read_batch(self, batch):
"""Process a single, partitioned read.
:type batch: mapping
:param batch:
one of the mappings returned from an earlier call to
:meth:`generate_read_batches`.
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
kwargs = batch['read']
keyset_dict = kwargs.pop('keyset')
kwargs['keyset'] = KeySet._from_dict(keyset_dict)
return self._get_snapshot().read(
partition=batch['partition'], **kwargs)
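A minimal read-batch sketch (assumed names: ``database``, ``citizens``), processing every partition in-line:

from google.cloud.spanner_v1.keyset import KeySet

batch_txn = database.batch_transaction()
for batch in batch_txn.generate_read_batches(
        table='citizens', columns=['email', 'first_name'],
        keyset=KeySet(all_=True)):
    for row in batch_txn.process_read_batch(batch):
        print(row)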

def generate_query_batches(
self, sql, params=None, param_types=None,
partition_size_bytes=None, max_partitions=None):
"""Start a partitioned query operation.
Uses the ``PartitionQuery`` API request to start a partitioned
query operation. Returns a list of batch information needed to
perform the actual queries.
:type sql: str
:param sql: SQL query statement
:type params: dict, {str -> column value}
:param params: values for parameter replacement. Keys must match
the names used in ``sql``.
:type param_types: dict[str -> Union[dict, .types.Type]]
:param param_types:
(Optional) maps explicit types for one or more param values;
required if parameters are passed.
:type partition_size_bytes: int
:param partition_size_bytes:
(Optional) desired size for each partition generated. The service
uses this as a hint; the actual partition size may differ.
:type max_partitions: int
:param max_partitions:
(Optional) desired maximum number of partitions generated. The
service uses this as a hint; the actual number of partitions may
differ.
:rtype: iterable of dict
:returns:
mappings of information used to perform actual partitioned queries via
:meth:`process_query_batch`.
"""
partitions = self._get_snapshot().partition_query(
sql=sql, params=params, param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions)

query_info = {'sql': sql}
if params:
query_info['params'] = params
query_info['param_types'] = param_types

for partition in partitions:
yield {'partition': partition, 'query': query_info}

def process_query_batch(self, batch):
"""Process a single, partitioned query.
:type batch: mapping
:param batch:
one of the mappings returned from an earlier call to
:meth:`generate_query_batches`.
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
return self._get_snapshot().execute_sql(
partition=batch['partition'], **batch['query'])
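The query-side counterpart, again with assumed names:

batch_txn = database.batch_transaction()
for batch in batch_txn.generate_query_batches(
        sql='SELECT first_name, email FROM citizens'):
    for row in batch_txn.process_query_batch(batch):
        print(row)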

def process(self, batch):
"""Process a single, partitioned query or read.
:type batch: mapping
:param batch:
one of the mappings returned from an earlier call to
:meth:`generate_read_batches` or :meth:`generate_query_batches`.
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
:raises ValueError: if batch does not contain either 'read' or 'query'
"""
if 'query' in batch:
return self.process_query_batch(batch)
if 'read' in batch:
return self.process_read_batch(batch)
raise ValueError("Invalid batch")

def close(self):
"""Clean up underlying session."""
if self._session is not None:
self._session.delete()
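Because ``close`` deletes the underlying session, wrapping the batch loop in ``try`` / ``finally`` keeps cleanup deterministic (sketch, hypothetical query):

batch_txn = database.batch_transaction()
try:
    for batch in batch_txn.generate_query_batches(sql='SELECT 1'):
        for row in batch_txn.process(batch):  # dispatches on 'read' / 'query'
            print(row)
finally:
    batch_txn.close()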


def _check_ddl_statements(value):
"""Validate DDL Statements used to define database schema.
67 changes: 67 additions & 0 deletions spanner/google/cloud/spanner_v1/keyset.py
@@ -85,6 +85,35 @@ def _to_pb(self):

return KeyRangePB(**kwargs)

def _to_dict(self):
"""Return keyrange's state as a dict.
:rtype: dict
:returns: state of this instance.
"""
mapping = {}

if self.start_open:
mapping['start_open'] = self.start_open

if self.start_closed:
mapping['start_closed'] = self.start_closed

if self.end_open:
mapping['end_open'] = self.end_open

if self.end_closed:
mapping['end_closed'] = self.end_closed

return mapping
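For illustration (assumed key values; ``KeyRange`` kwargs as used by ``KeySet._from_dict`` below), only the populated bounds appear in the serialized state:

krange = KeyRange(start_closed=['alice'], end_open=['bob'])
assert krange._to_dict() == {'start_closed': ['alice'], 'end_open': ['bob']}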

def __eq__(self, other):
"""Compare by serialized state."""
if not isinstance(other, self.__class__):
return NotImplemented
return self._to_dict() == other._to_dict()



class KeySet(object):
"""Identify table rows via keys / ranges.
@@ -122,3 +151,41 @@ def _to_pb(self):
kwargs['ranges'] = [krange._to_pb() for krange in self.ranges]

return KeySetPB(**kwargs)

def _to_dict(self):
"""Return keyset's state as a dict.
The result can be used to serialize the instance and reconstitute
it later using :meth:`_from_dict`.
:rtype: dict
:returns: state of this instance.
"""
if self.all_:
return {'all': True}

return {
'keys': self.keys,
'ranges': [keyrange._to_dict() for keyrange in self.ranges],
}

def __eq__(self, other):
"""Compare by serialized state."""
if not isinstance(other, self.__class__):
return NotImplemented
return self._to_dict() == other._to_dict()

@classmethod
def _from_dict(cls, mapping):
"""Create an instance from the corresponding state mapping.
:type mapping: dict
:param mapping: the instance state.
"""
if mapping.get('all'):
return cls(all_=True)

r_mappings = mapping.get('ranges', ())
ranges = [KeyRange(**r_mapping) for r_mapping in r_mappings]

return cls(keys=mapping.get('keys', ()), ranges=ranges)
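Taken together, ``_to_dict`` / ``_from_dict`` round-trip a keyset; a quick sketch with assumed key values:

keyset = KeySet(
    keys=[['alice']],
    ranges=[KeyRange(start_closed=['bob'], end_open=['carol'])])
assert KeySet._from_dict(keyset._to_dict()) == keyset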
