Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add savepoint API #433

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions edgedb/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import socket
import ssl
import typing
import uuid

from . import abstract
from . import base_client
Expand Down Expand Up @@ -322,6 +323,10 @@ def _exclusive(self):
finally:
self._locked = False

async def savepoint(self) -> transaction.Savepoint:
name = "s" + uuid.uuid4().hex
return await self._declare_savepoint(name)


class AsyncIORetry(transaction.BaseRetry):

Expand Down
15 changes: 15 additions & 0 deletions edgedb/blocking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import threading
import time
import typing
import uuid

from . import abstract
from . import base_client
Expand Down Expand Up @@ -270,6 +271,14 @@ async def close(self, timeout=None):
self._closing = False


class Savepoint(transaction.Savepoint):
def release(self):
self._tx._client._iter_coroutine(super().release())

def rollback(self):
self._tx._client._iter_coroutine(super().rollback())


class Iteration(transaction.BaseTransaction, abstract.Executor):

__slots__ = ("_managed", "_lock")
Expand Down Expand Up @@ -320,6 +329,12 @@ def _exclusive(self):
finally:
self._lock.release()

def savepoint(self) -> Savepoint:
name = "s" + uuid.uuid4().hex
return self._client._iter_coroutine(
self._declare_savepoint(name, cls=Savepoint)
)


class Retry(transaction.BaseRetry):

Expand Down
51 changes: 51 additions & 0 deletions edgedb/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#


from __future__ import annotations

import enum

from . import abstract
Expand All @@ -32,12 +34,47 @@ class TransactionState(enum.Enum):
FAILED = 4


class Savepoint:
__slots__ = ('_name', '_tx', '_active')

def __init__(self, name: str, transaction: BaseTransaction):
self._name = name
self._tx = transaction
self._active = True

@property
def active(self):
return self._active

def _ensure_active(self):
if not self._active:
raise errors.InterfaceError(
f"savepoint {self._name!r} is no longer active"
)

async def release(self):
self._ensure_active()
await self._tx._privileged_execute(f"release savepoint {self._name}")
del self._tx._savepoints[self._name]
self._active = False

async def rollback(self):
self._ensure_active()
await self._tx._privileged_execute(
f"rollback to savepoint {self._name}"
)
names = list(self._tx._savepoints)
for name in names[names.index(self._name):]:
self._tx._savepoints.pop(name)._active = False


class BaseTransaction:

__slots__ = (
'_client',
'_connection',
'_options',
'_savepoints',
'_state',
'__retry',
'__iteration',
Expand All @@ -48,6 +85,7 @@ def __init__(self, retry, client, iteration):
self._client = client
self._connection = None
self._options = retry._options.transaction_options
self._savepoints = {}
self._state = TransactionState.NEW
self.__retry = retry
self.__iteration = iteration
Expand Down Expand Up @@ -128,6 +166,9 @@ async def _exit(self, extype, ex):
if not self.__started:
return False

for sp in self._savepoints.values():
sp._active = False

try:
if extype is None:
query = self._make_commit_query()
Expand Down Expand Up @@ -200,6 +241,16 @@ async def _privileged_execute(self, query: str) -> None:
state=self._get_state(),
))

async def _declare_savepoint(self, savepoint: str, cls=Savepoint):
if savepoint in self._savepoints:
raise errors.InterfaceError(
f"savepoint {savepoint!r} already exists"
)
await self._ensure_transaction()
await self._privileged_execute(f"declare savepoint {savepoint}")
self._savepoints[savepoint] = rv = cls(savepoint, self)
return rv


class BaseRetry:

Expand Down
51 changes: 51 additions & 0 deletions tests/test_async_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class TestAsyncTx(tb.AsyncQueryTestCase):
};
'''

TEARDOWN_METHOD = '''
DELETE test::TransactionTest;
'''

TEARDOWN = '''
DROP TYPE test::TransactionTest;
'''
Expand Down Expand Up @@ -104,3 +108,50 @@ async def test_async_transaction_exclusive(self):
):
await asyncio.wait_for(f1, timeout=5)
await asyncio.wait_for(f2, timeout=5)

async def test_async_transaction_savepoint_1(self):
async for tx in self.client.transaction():
async with tx:
sp1 = await tx.savepoint()
sp2 = await tx.savepoint()
await tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
await sp2.release()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
await sp2.release()
await sp1.release()

result = await self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])

async def test_async_transaction_savepoint_2(self):
async for tx in self.client.transaction():
async with tx:
await tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
sp1 = await tx.savepoint()
await tx.execute('''
INSERT test::TransactionTest { name := '2' };
''')
sp2 = await tx.savepoint()
await tx.execute('''
INSERT test::TransactionTest { name := '3' };
''')
await sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
await sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
await sp2.rollback()

result = await self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])
51 changes: 51 additions & 0 deletions tests/test_sync_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class TestSyncTx(tb.SyncQueryTestCase):
};
'''

TEARDOWN_METHOD = '''
DELETE test::TransactionTest;
'''

TEARDOWN = '''
DROP TYPE test::TransactionTest;
'''
Expand Down Expand Up @@ -113,3 +117,50 @@ def test_sync_transaction_exclusive(self):
):
f1.result(timeout=5)
f2.result(timeout=5)

def test_sync_transaction_savepoint_1(self):
for tx in self.client.transaction():
with tx:
sp1 = tx.savepoint()
sp2 = tx.savepoint()
tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
sp2.release()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
sp2.release()
sp1.release()

result = self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])

def test_sync_transaction_savepoint_2(self):
for tx in self.client.transaction():
with tx:
tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
sp1 = tx.savepoint()
tx.execute('''
INSERT test::TransactionTest { name := '2' };
''')
sp2 = tx.savepoint()
tx.execute('''
INSERT test::TransactionTest { name := '3' };
''')
sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
sp2.rollback()

result = self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])