Skip to content

Commit

Permalink
task wrapper decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
tomerlf1 committed Jan 31, 2024
1 parent c3c5abd commit 51c8eaf
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 5 deletions.
7 changes: 7 additions & 0 deletions gce_rescue/gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

""" Initilization Instance() with VM information. """
import sys
import logging

from googleapiclient.discovery import Resource
from googleapiclient.errors import HttpError
Expand All @@ -26,6 +27,7 @@
from gce_rescue.tasks.pre_validations import Validations
from gce_rescue.config import get_config

_logger = logging.getLogger(__name__)

def get_instance_info(
compute: Resource,
Expand All @@ -50,6 +52,9 @@ def guess_guest(data: Dict) -> str:
Default: projects/debian-cloud/global/images/family/debian-11"""

guests = get_config('source_guests')
if not data.get('disks'):
_logger.error(f'Unable to get disks for vm. Check whether a boot disk is attached to your vm.')
raise Exception("No boot disk was found for vm")
for disk in data['disks']:
if disk['boot']:
if 'architecture' in disk:
Expand All @@ -58,6 +63,8 @@ def guess_guest(data: Dict) -> str:
arch = 'x86_64'
guest_default = guests[arch][0]
guest_name = guest_default.split('/')[-1]
if not disk.get('licenses'):
return guest_default
for lic in disk['licenses']:
if guest_name in lic:
guest_default = guests[arch][1]
Expand Down
2 changes: 1 addition & 1 deletion gce_rescue/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from gce_rescue.gce import Instance

def tip_connect_ssh(vm: Instance) -> str:
return (f'└── Your instance is READY! You can now connect your instance '
return (f'└── Your instance is READY! You can now connect to your instance '
f' {vm.name} via:\n 1. CLI. (add --tunnel-through-iap if necessary)\n'
f' $ gcloud compute ssh {vm.name} --zone={vm.zone} '
f'--project={vm.project} --ssh-flag="-o StrictHostKeyChecking=no"\n OR\n'
Expand Down
4 changes: 2 additions & 2 deletions gce_rescue/tasks/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def call_tasks(vm: Instance, action: str) -> None:
tracker = Tracker(total_tasks)
tracker.start()

for task in tasks:
for task_index, task in enumerate(tasks, 1):
execute = task['name']
args = task['args'][0]

execute(**args)
execute(**args, task_index=task_index, total_tasks=total_tasks)
tracker.advance(step = 1)

if async_backup_thread:
Expand Down
6 changes: 5 additions & 1 deletion gce_rescue/tasks/disks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from gce_rescue.tasks.keeper import wait_for_operation
from gce_rescue.tasks.backup import create_snapshot
from gce_rescue.utils import ThreadHandler as Handler
from gce_rescue.utils import ThreadHandler as Handler, tasks_wrapper
from googleapiclient.errors import HttpError

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -130,6 +130,7 @@ def list_disk(vm, project_data: Dict, label_filter: str) -> Dict:
return result['items']


@tasks_wrapper
def attach_disk(
vm,
disk_name: str,
Expand Down Expand Up @@ -189,6 +190,7 @@ def take_snapshot(vm, join_snapshot=None) -> None:
snapshot_thread.join()


@tasks_wrapper
def create_rescue_disk(vm) -> None:
device_name = vm.disks['device_name']
# task1 = multitasks.Handler(
Expand Down Expand Up @@ -222,6 +224,8 @@ def list_snapshot(vm) -> str:
return ''
return snapshot_name


@tasks_wrapper
def restore_original_disk(vm) -> None:
""" Restore tasks to the original disk """
device_name = vm.disks['device_name']
Expand Down
3 changes: 3 additions & 0 deletions gce_rescue/tasks/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

from gce_rescue.config import get_config
from gce_rescue.tasks.keeper import wait_for_operation, wait_for_os_boot
from gce_rescue.utils import tasks_wrapper
from typing import Dict
import logging

_logger = logging.getLogger(__name__)

@tasks_wrapper
def set_metadata(vm) -> Dict:
"""Configure Instance custom metadata.
https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMetadata
Expand Down Expand Up @@ -53,6 +55,7 @@ def set_metadata(vm) -> Dict:
return result


@tasks_wrapper
def restore_metadata_items(vm, remove_rescue_mode: bool = False) -> Dict:
"""Restore original metadata.items after the instance is running again."""

Expand Down
4 changes: 4 additions & 0 deletions gce_rescue/tasks/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

from gce_rescue.gce import Instance
from gce_rescue.tasks.keeper import wait_for_operation
from gce_rescue.utils import tasks_wrapper
import logging

_logger = logging.getLogger(__name__)


@tasks_wrapper
def start_instance(vm: Instance) -> str:
"""Start instance."""

Expand All @@ -38,6 +41,7 @@ def start_instance(vm: Instance) -> str:
return vm.status


@tasks_wrapper
def stop_instance(vm: Instance) -> str:
"""Stop instance."""
_logger.info(f'Stopping {vm.name}...')
Expand Down
30 changes: 29 additions & 1 deletion gce_rescue/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
""" List of classes and functions to be used across the code. """

from time import sleep
from datetime import datetime

from googleapiclient.errors import HttpError

import logging
import multiprocessing
from threading import Thread
Expand Down Expand Up @@ -140,4 +144,28 @@ def read_input(msg: str) -> None:
if input_answer.upper() != 'Y':
print(f'got input: "{input_answer}". Aborting')
sys.exit(1)



def tasks_wrapper(task_func):
def inner(*args, **kwargs):
_logger.info(f"task {task_func.__name__} "
f"started at: {datetime.now().strftime('%H:%M:%S')}")
task_index = kwargs.pop('task_index') if kwargs.get('task_index') else None
total_tasks = kwargs.pop('total_tasks') if kwargs.get('total_tasks') \
else None
try:
res = task_func(*args, **kwargs)
except HttpError as e:
_logger.error(f'HttpError caught on task {task_func.__name__} '
f'with error: {e}')
return
_logger.info(f"task {task_func.__name__} ended at: "
f"{datetime.now().strftime('%H:%M:%S')}")
if task_index and total_tasks:
_logger.info(f'Progress: {task_index}/{total_tasks} tasks completed')
print(f'finished {task_func.__name__} {task_index}/{total_tasks}'
f' tasks completed')

return res

return inner

0 comments on commit 51c8eaf

Please sign in to comment.