Rewrite providers to use pydantic in init #219

Merged · 5 commits · Mar 14, 2024
Changes from all commits
5 changes: 4 additions & 1 deletion backuper/backup_targets/mariadb.py
@@ -79,7 +79,10 @@ def _mariadb_connection(self) -> str:
version = match.group(0)
break
if version is None: # pragma: no cover
msg = f"mariadb_connection error processing sql result, version unknown: {result}"
msg = (
f"mariadb_connection error processing sql result, "
f"version unknown: {result}"
)
log.error(msg)
raise ValueError(msg)
log.info("mariadb_connection calculated version: %s", version)
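The wrapping pattern used across these files relies on Python's compile-time concatenation of adjacent string literals inside parentheses, so the reformatted message is identical to the original one-liner. A minimal sketch (the result value is invented):

    result = ["10.11.6-MariaDB"]
    msg = (
        f"mariadb_connection error processing sql result, "
        f"version unknown: {result}"
    )
    one_liner = f"mariadb_connection error processing sql result, version unknown: {result}"
    assert msg == one_liner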
5 changes: 4 additions & 1 deletion backuper/backup_targets/mysql.py
@@ -79,7 +79,10 @@ def _mysql_connection(self) -> str:
version = match.group(0)
break
if version is None: # pragma: no cover
msg = f"mysql_connection error processing sql result, version unknown: {result}"
msg = (
f"mysql_connection error processing sql result, "
f"version unknown: {result}"
)
log.error(msg)
raise ValueError(msg)
log.info("mysql_connection calculated version: %s", version)
10 changes: 8 additions & 2 deletions backuper/backup_targets/postgresql.py
@@ -95,7 +95,10 @@ def _postgres_connection(self) -> str:
version = match.group(0).strip().split(" ")[1]
break
if version is None: # pragma: no cover
msg = f"postgres_connection error processing sql result, version unknown: {result}"
msg = (
f"postgres_connection error processing sql result, "
f"version unknown: {result}"
)
log.error(msg)
raise ValueError(msg)
log.info("postgres_connection calculated version: %s", version)
@@ -105,7 +108,10 @@ def _backup(self) -> Path:
escaped_dbname = core.safe_text_version(self.target_model.db)
name = f"{escaped_dbname}_{self.db_version}"
out_file = core.get_new_backup_path(self.env_name, name, sql=True)
shell_pg_dump_db = f"pg_dump --clean --if-exists -v -O -d {self.escaped_conn_uri} -f {out_file}"
shell_pg_dump_db = (
f"pg_dump --clean --if-exists -v -O -d "
f"{self.escaped_conn_uri} -f {out_file}"
)
log.debug("start pg_dump in subprocess: %s", shell_pg_dump_db)
core.run_subprocess(shell_pg_dump_db)
log.debug("finished pg_dump, output: %s", out_file)
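The wrapped pg_dump command is unchanged in behavior. With invented placeholder values it expands as follows (a sketch, not output from the codebase):

    escaped_conn_uri = "postgresql://user:***@localhost:5432/db"
    out_file = "/var/backuper/env_db_15.sql"
    shell_pg_dump_db = (
        f"pg_dump --clean --if-exists -v -O -d "
        f"{escaped_conn_uri} -f {out_file}"
    )
    # pg_dump --clean --if-exists -v -O -d postgresql://user:***@localhost:5432/db -f /var/backuper/env_db_15.sql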
19 changes: 19 additions & 0 deletions backuper/backup_targets/targets_mapping.py
@@ -0,0 +1,19 @@
from backuper.backup_targets import (
base_target,
file,
folder,
mariadb,
mysql,
postgresql,
)
from backuper.config import BackupTargetEnum


def get_target_cls_map() -> dict[str, type[base_target.BaseBackupTarget]]:
return {
BackupTargetEnum.FILE: file.File,
BackupTargetEnum.FOLDER: folder.Folder,
BackupTargetEnum.MARIADB: mariadb.MariaDB,
BackupTargetEnum.POSTGRESQL: postgresql.PostgreSQL,
BackupTargetEnum.MYSQL: mysql.MySQL,
}
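A rough sketch of how this new mapping is consumed, mirroring backup_targets() in backuper/main.py further down (the loop body is condensed here):

    from backuper import core
    from backuper.backup_targets import targets_mapping

    target_cls_map = targets_mapping.get_target_cls_map()
    for target_model in core.create_target_models():
        target_cls = target_cls_map[target_model.name]
        # the whole pydantic model is passed in, not unpacked **kwargs
        target = target_cls(target_model=target_model)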
5 changes: 3 additions & 2 deletions backuper/config.py
@@ -27,7 +27,7 @@

class UploadProviderEnum(StrEnum):
LOCAL_FILES_DEBUG = "debug"
GOOGLE_CLOUD_STORAGE = "gcs"
GCS = "gcs"
AWS_S3 = "aws"
AZURE = "azure"

@@ -81,7 +81,8 @@ def check_smtp_setup(self) -> Self:
smtp_settings = [self.SMTP_HOST, self.SMTP_FROM_ADDR, self.SMTP_TO_ADDRS]
if any(smtp_settings) != all(smtp_settings): # pragma: no cover
raise ValueError(
"parameters SMTP_HOST, SMTP_FROM_ADDR, SMTP_TO_ADDRS must be all either set or not."
"parameters SMTP_HOST, SMTP_FROM_ADDR, SMTP_TO_ADDRS "
"must be all either set or not."
)
return self

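The any(...) != all(...) comparison is an "all or nothing" check: it is true exactly when some but not all of the SMTP settings are filled in. A quick illustration with invented values:

    smtp = ["smtp.example.com", "", ""]            # partially configured
    assert any(smtp) != all(smtp)                  # True -> raises ValueError

    smtp = ["", "", ""]                            # nothing configured
    assert not (any(smtp) != all(smtp))            # False -> passes

    smtp = ["smtp.example.com", "from@x", "to@x"]  # fully configured
    assert not (any(smtp) != all(smtp))            # False -> passes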
38 changes: 9 additions & 29 deletions backuper/core.py
@@ -13,15 +13,7 @@
from pydantic import BaseModel

from backuper import config
from backuper.models.backup_target_models import (
DirectoryTargetModel,
MariaDBTargetModel,
MySQLTargetModel,
PostgreSQLTargetModel,
SingleFileTargetModel,
TargetModel,
)
from backuper.models.upload_provider_models import ProviderModel
from backuper.models import backup_target_models, models_mapping, upload_provider_models

log = logging.getLogger(__name__)

@@ -152,18 +144,10 @@ def _validate_model(
return validated_target


def create_target_models() -> list[TargetModel]:
target_map: dict[str, type[TargetModel]] = {
config.BackupTargetEnum.FILE: SingleFileTargetModel,
config.BackupTargetEnum.FOLDER: DirectoryTargetModel,
config.BackupTargetEnum.MARIADB: MariaDBTargetModel,
config.BackupTargetEnum.MYSQL: MySQLTargetModel,
config.BackupTargetEnum.POSTGRESQL: PostgreSQLTargetModel,
}
def create_target_models() -> list[backup_target_models.TargetModel]:
target_map = models_mapping.get_target_map()

log.debug(target_map)

targets: list[TargetModel] = []
targets: list[backup_target_models.TargetModel] = []
for env_name, env_value in os.environ.items():
env_name_lowercase = env_name.lower()
log.debug("processing env variable %s", env_name_lowercase)
@@ -178,22 +162,18 @@ def create_target_models() -> list[TargetModel]:
return targets


def create_provider_model() -> ProviderModel:
target_map: dict[config.UploadProviderEnum, type[ProviderModel]] = {}
for target_model in ProviderModel.__subclasses__():
name = config.UploadProviderEnum(
target_model.__name__.lower().removesuffix("providermodel")
)
target_map[name] = target_model
def create_provider_model() -> upload_provider_models.ProviderModel:
provider_map = models_mapping.get_provider_map()

log.info("start validating BACKUP_PROVIDER environment variable")

base_provider = _validate_model(
"backup_provider",
config.options.BACKUP_PROVIDER,
ProviderModel,
upload_provider_models.ProviderModel,
value_whitespace_split=True,
)
target_model_cls = target_map[base_provider.name]
target_model_cls = provider_map[base_provider.name]
return _validate_model(
"backup_provider", config.options.BACKUP_PROVIDER, target_model_cls
)
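create_provider_model() validates twice on purpose: the first pass against the base ProviderModel only extracts `name`, which then selects the concrete model class that enforces provider-specific fields. A rough sketch of the flow (the "name=..." value syntax is assumed for illustration):

    raw = "name=debug"
    base = _validate_model(
        "backup_provider", raw, upload_provider_models.ProviderModel,
        value_whitespace_split=True,
    )
    cls = models_mapping.get_provider_map()[base.name]  # -> DebugProviderModel
    provider_model = _validate_model("backup_provider", raw, cls)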
53 changes: 25 additions & 28 deletions backuper/main.py
@@ -10,17 +10,18 @@
from typing import NoReturn

from backuper import config, core
from backuper.backup_targets.base_target import BaseBackupTarget
from backuper.backup_targets.file import File
from backuper.backup_targets.folder import Folder
from backuper.backup_targets.mariadb import MariaDB
from backuper.backup_targets.mysql import MySQL
from backuper.backup_targets.postgresql import PostgreSQL
from backuper.backup_targets import (
base_target,
targets_mapping,
)
from backuper.notifications.notifications_context import (
PROGRAM_STEP,
NotificationsContext,
)
from backuper.upload_providers import BaseUploadProvider
from backuper.upload_providers import (
base_provider,
providers_mapping,
)

exit_event = threading.Event()
log = logging.getLogger(__name__)
@@ -32,19 +33,18 @@ def quit(sig: int, frame: FrameType | None) -> None:


@NotificationsContext(step_name=PROGRAM_STEP.SETUP_PROVIDER)
def backup_provider() -> BaseUploadProvider:
backup_provider_map: dict[config.UploadProviderEnum, type[BaseUploadProvider]] = {}
for backup_provider in BaseUploadProvider.__subclasses__():
backup_provider_map[backup_provider.target_name] = backup_provider # type: ignore
def backup_provider() -> base_provider.BaseUploadProvider:
provider_cls_map = providers_mapping.get_provider_cls_map()

provider_model = core.create_provider_model()
log.info(
"initializing provider: `%s`",
provider_model.name,
)
provider_target_cls = backup_provider_map[provider_model.name]

provider_target_cls = provider_cls_map[provider_model.name]
log.debug("initializing %s with %s", provider_target_cls, provider_model)
res_backup_provider = provider_target_cls(**provider_model.model_dump())
res_backup_provider = provider_target_cls(target_provider=provider_model)
log.info(
"success initializing provider: `%s`",
provider_model.name,
@@ -53,16 +53,10 @@


@NotificationsContext(step_name=PROGRAM_STEP.SETUP_TARGETS)
def backup_targets() -> list[BaseBackupTarget]:
backup_targets_map: dict[str, type[BaseBackupTarget]] = {
config.BackupTargetEnum.FILE: File,
config.BackupTargetEnum.FOLDER: Folder,
config.BackupTargetEnum.MARIADB: MariaDB,
config.BackupTargetEnum.POSTGRESQL: PostgreSQL,
config.BackupTargetEnum.MYSQL: MySQL,
}

backup_targets: list[BaseBackupTarget] = []
def backup_targets() -> list[base_target.BaseBackupTarget]:
backup_target_cls_map = targets_mapping.get_target_cls_map()

backup_targets: list[base_target.BaseBackupTarget] = []
target_models = core.create_target_models()
if not target_models:
raise RuntimeError("Found 0 backup targets, at least 1 is required.")
@@ -74,7 +68,7 @@ def backup_targets() -> list[BaseBackupTarget]:
"initializing target: `%s`",
target_model.env_name,
)
backup_target_cls = backup_targets_map[target_model.name]
backup_target_cls = backup_target_cls_map[target_model.name]
log.debug("initializing %s with %s", backup_target_cls, target_model)
backup_targets.append(backup_target_cls(target_model=target_model))
log.info(
@@ -121,14 +115,17 @@ def shutdown() -> NoReturn:  # pragma: no cover
sys.exit(0)
else:
log.warning(
"noooo, exiting! i am now killing myself with %d daemon threads force killed. "
"you can extend this time using environment SIGTERM_TIMEOUT_SECS.",
"noooo, exiting! i am now killing myself with %d daemon threads "
"force killed. you can extend this time using environment "
"SIGTERM_TIMEOUT_SECS.",
threading.active_count() - 1,
)
sys.exit(1)


def run_backup(target: BaseBackupTarget, provider: BaseUploadProvider) -> None:
def run_backup(
target: base_target.BaseBackupTarget, provider: base_provider.BaseUploadProvider
) -> None:
log.info("start making backup of target: `%s`", target.env_name)
with NotificationsContext(
step_name=PROGRAM_STEP.BACKUP_CREATE, env_name=target.env_name
@@ -137,7 +134,7 @@ def run_backup(target: BaseBackupTarget, provider: BaseUploadProvider) -> None:
log.info(
"backup file created: %s, starting post save upload to provider %s",
backup_file,
provider.target_name,
provider.__class__.__name__,
)
with NotificationsContext(
step_name=PROGRAM_STEP.UPLOAD,
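Condensed, the resulting startup path looks like this (a simplified sketch; threading, signal handling, and notification contexts are omitted):

    provider = backup_provider()   # model -> provider class via providers_mapping
    targets = backup_targets()     # env vars -> models -> classes via targets_mapping
    for target in targets:
        run_backup(target, provider)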
21 changes: 21 additions & 0 deletions backuper/models/models_mapping.py
@@ -0,0 +1,21 @@
from backuper.config import BackupTargetEnum, UploadProviderEnum
from backuper.models import backup_target_models, upload_provider_models


def get_target_map() -> dict[str, type[backup_target_models.TargetModel]]:
return {
BackupTargetEnum.FILE: backup_target_models.SingleFileTargetModel,
BackupTargetEnum.FOLDER: backup_target_models.DirectoryTargetModel,
BackupTargetEnum.MARIADB: backup_target_models.MariaDBTargetModel,
BackupTargetEnum.MYSQL: backup_target_models.MySQLTargetModel,
BackupTargetEnum.POSTGRESQL: backup_target_models.PostgreSQLTargetModel,
}


def get_provider_map() -> dict[str, type[upload_provider_models.ProviderModel]]:
return {
UploadProviderEnum.AZURE: upload_provider_models.AzureProviderModel,
UploadProviderEnum.LOCAL_FILES_DEBUG: upload_provider_models.DebugProviderModel,
UploadProviderEnum.GCS: upload_provider_models.GCSProviderModel,
UploadProviderEnum.AWS_S3: upload_provider_models.AWSProviderModel,
}
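Both mappings are annotated as dict[str, type[...]] even though they are keyed by enum members; this works because StrEnum members are plain str instances, so lookups by either the member or its string value hit the same key:

    from enum import StrEnum

    class UploadProviderEnum(StrEnum):
        GCS = "gcs"

    m = {UploadProviderEnum.GCS: "GCSProviderModel"}
    assert m["gcs"] == m[UploadProviderEnum.GCS]  # same hash, same key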
7 changes: 5 additions & 2 deletions backuper/models/upload_provider_models.py
@@ -6,14 +6,15 @@


class ProviderModel(BaseModel):
name: config.UploadProviderEnum
name: str


class DebugProviderModel(ProviderModel):
pass
name: str = config.UploadProviderEnum.LOCAL_FILES_DEBUG


class GCSProviderModel(ProviderModel):
name: str = config.UploadProviderEnum.GCS
bucket_name: str
bucket_upload_path: str
service_account_base64: SecretStr
@@ -29,6 +30,7 @@ def process_service_account_base64(


class AWSProviderModel(ProviderModel):
name: str = config.UploadProviderEnum.AWS_S3
bucket_name: str
bucket_upload_path: str
key_id: str
@@ -38,5 +40,6 @@ class AWSProviderModel(ProviderModel):


class AzureProviderModel(ProviderModel):
name: str = config.UploadProviderEnum.AZURE
container_name: str
connect_string: SecretStr
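With `name` widened to str on the base class and narrowed to a default on each subclass, the first validation pass can parse any provider value, while concrete models come out pre-tagged. A small sketch (pydantic v2 semantics assumed):

    from backuper.models.upload_provider_models import DebugProviderModel, ProviderModel

    base = ProviderModel(name="debug")  # first pass: name is just a str
    debug = DebugProviderModel()        # name already defaults to "debug"
    assert debug.name == base.name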
12 changes: 0 additions & 12 deletions backuper/upload_providers/__init__.py
@@ -1,13 +1 @@
from .aws_s3 import UploadProviderAWS
from .azure import UploadProviderAzure
from .base_provider import BaseUploadProvider
from .debug import UploadProviderLocalDebug
from .google_cloud_storage import UploadProviderGCS

__all__ = [
"BaseUploadProvider",
"UploadProviderAWS",
"UploadProviderGCS",
"UploadProviderLocalDebug",
"UploadProviderAzure",
]
32 changes: 10 additions & 22 deletions backuper/upload_providers/aws_s3.py
@@ -4,9 +4,9 @@

import boto3
from boto3.s3.transfer import TransferConfig
from pydantic import SecretStr

from backuper import config, core
from backuper import core
from backuper.models.upload_provider_models import AWSProviderModel
from backuper.upload_providers.base_provider import BaseUploadProvider

log = logging.getLogger(__name__)
@@ -16,33 +16,21 @@ class DeleteItemDict(TypedDict):
Key: str


class UploadProviderAWS(
BaseUploadProvider,
name=config.UploadProviderEnum.AWS_S3,
):
class UploadProviderAWS(BaseUploadProvider):
"""AWS S3 bucket for storing backups"""

def __init__(
self,
bucket_name: str,
bucket_upload_path: str,
key_id: str,
key_secret: SecretStr,
region: str,
max_bandwidth: int | None,
**kwargs: str,
) -> None:
self.bucket_upload_path = bucket_upload_path
self.max_bandwidth = max_bandwidth
def __init__(self, target_provider: AWSProviderModel) -> None:
self.bucket_upload_path = target_provider.bucket_upload_path
self.max_bandwidth = target_provider.max_bandwidth

s3: Any = boto3.resource(
"s3",
region_name=region,
aws_access_key_id=key_id,
aws_secret_access_key=key_secret.get_secret_value(),
region_name=target_provider.region,
aws_access_key_id=target_provider.key_id,
aws_secret_access_key=target_provider.key_secret.get_secret_value(),
)

self.bucket = s3.Bucket(bucket_name)
self.bucket = s3.Bucket(target_provider.bucket_name)
self.transfer_config = TransferConfig(max_bandwidth=self.max_bandwidth)

def _post_save(self, backup_file: Path) -> str:
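End to end, the new init is driven entirely by the model object. A hypothetical wiring (field values invented; key_secret, region, and max_bandwidth are read off the attribute accesses above, since the AWSProviderModel diff is truncated here):

    from pydantic import SecretStr

    from backuper.models.upload_provider_models import AWSProviderModel
    from backuper.upload_providers.aws_s3 import UploadProviderAWS

    model = AWSProviderModel(
        bucket_name="my-backups",
        bucket_upload_path="backuper",
        key_id="AKIA...",
        key_secret=SecretStr("..."),
        region="eu-central-1",
        max_bandwidth=None,
    )
    provider = UploadProviderAWS(target_provider=model)  # no **kwargs unpacking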