diff --git a/gce_rescue/tasks/actions.py b/gce_rescue/tasks/actions.py index 9903f96..7d3c453 100644 --- a/gce_rescue/tasks/actions.py +++ b/gce_rescue/tasks/actions.py @@ -50,12 +50,6 @@ def _list_tasks(vm: Instance, action: str) -> List: 'vm': vm }] }, - { - 'name': take_snapshot, - 'args': [{ - 'vm': vm - }] - }, { 'name': create_rescue_disk, 'args': [{ @@ -127,9 +121,13 @@ def _list_tasks(vm: Instance, action: str) -> List: def call_tasks(vm: Instance, action: str) -> None: """ Loop tasks dict and execute """ tasks = _list_tasks(vm = vm, action = action) - if get_config('skip-snapshot'): - _logger.info(f'Skipping snapshot backup.') - tasks = [task for task in tasks if task['name'].__name__ != 'take_snapshot'] + async_backup_thread = None + if action == 'set_rescue_mode': + if get_config('skip-snapshot'): + _logger.info(f'Skipping snapshot backup.') + else: + take_snapshot(vm) + async_backup_thread = True total_tasks = len(tasks) tracker = Tracker(total_tasks) @@ -142,4 +140,8 @@ def call_tasks(vm: Instance, action: str) -> None: execute(**args) tracker.advance(step = 1) - tracker.finish() \ No newline at end of file + if async_backup_thread: + _logger.info(f'Waiting for async backup to finish') + take_snapshot(vm, join_snapshot=True) + _logger.info('done.') + tracker.finish() diff --git a/gce_rescue/tasks/disks.py b/gce_rescue/tasks/disks.py index c277cfe..657fc9c 100644 --- a/gce_rescue/tasks/disks.py +++ b/gce_rescue/tasks/disks.py @@ -16,6 +16,7 @@ from typing import Dict import logging +from threading import Thread import googleapiclient.errors @@ -25,6 +26,7 @@ from googleapiclient.errors import HttpError _logger = logging.getLogger(__name__) +snapshot_thread = None def _create_rescue_disk(vm, source_disk: str) -> Dict: """ Create new temporary rescue disk based on source_disk. @@ -178,8 +180,13 @@ def _detach_disk(vm, disk: str) -> Dict: return result -def take_snapshot(vm) -> None: - create_snapshot(vm) +def take_snapshot(vm, join_snapshot=None) -> None: + global snapshot_thread + if not join_snapshot: + snapshot_thread = Thread(target=create_snapshot, args=(vm,), daemon=True) + snapshot_thread.start() + else: + snapshot_thread.join() def create_rescue_disk(vm) -> None: