Skip to content

Commit

Permalink
feat: bootstrap task (#233)
Browse files Browse the repository at this point in the history
* setup arg parser and skeleton of bootstrap task

* implement barebone all dbt models dict building

* iterate through models and get col list and schema path

* combine getting columns and finding descriptor

* update tests for bootstrap_task

* update tests for audit_task

* add model placeholders in bootstrap_task

add all models placeholders

* refactor dbt models info in a dataclass for cleaner type checking

* feat: bootstrap task (Sourcery refactored) (#234)

* 'Refactored by Sourcery'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Sourcery AI <>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix typo in new fragment

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed May 3, 2021
1 parent 6a2d381 commit 9d841e8
Show file tree
Hide file tree
Showing 17 changed files with 419 additions and 54 deletions.
3 changes: 3 additions & 0 deletions changelog/233.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dbt-sugar can now automatically generate bootstrap model descriptor files (schema.yml) for all your models as well as their columns. Bootstrap model descriptors will contain either placeholders for undocumented columns and model descrptions, unless columns are documented in other models and in that case `bootstrap` will populate those columns with their definitions.

You can generate bootstraps by calling `dbt-sugar bootstrap`. Running `bootstrap` is particularly useful when you want to run an **exhaustive** `audit` on all your model since the `audit` task does not, by itself, check your models agains the database to make it less resource hungry. A follow up code-change will introduce an `--exhaustive` option on the `audit` task which will call `bootstrap` first and run `audit` after.
6 changes: 3 additions & 3 deletions dbt_sugar/core/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Only use this class implemented by a child connector.
"""
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, Sequence

import sqlalchemy

Expand All @@ -31,7 +31,7 @@ def __init__(

def get_columns_from_table(
self, target_table: str, target_schema: str, use_describe: bool = False
) -> Optional[List[Tuple[Any]]]:
) -> Sequence[str]:
"""
Method that creates cursor to run a query.
Expand All @@ -40,7 +40,7 @@ def get_columns_from_table(
target_schema (str): schema to get the table from.
Returns:
Optional[List[Tuple[Any]]]: With the results of the query.
Optional[Sequence[str]]: With the results of the query.
"""
inspector = sqlalchemy.engine.reflection.Inspector.from_engine(self.engine)
columns = inspector.get_columns(target_table, target_schema)
Expand Down
4 changes: 2 additions & 2 deletions dbt_sugar/core/connectors/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Module dependent of the base connector.
"""
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, Sequence

import sqlalchemy
from snowflake.sqlalchemy import URL
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(

def get_columns_from_table(
self, target_table: str, target_schema: str, use_describe: bool = False
) -> Optional[List[Tuple[Any]]]:
) -> Sequence[str]:

# if user wants to use describe (more preformant but with caveat) method
# we re-implement column describe since snowflake.sqlalchemy is shit.
Expand Down
32 changes: 26 additions & 6 deletions dbt_sugar/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dbt_sugar.core.logger import GLOBAL_LOGGER as logger
from dbt_sugar.core.logger import log_manager
from dbt_sugar.core.task.audit import AuditTask
from dbt_sugar.core.task.bootstrap import BootstrapTask
from dbt_sugar.core.task.doc import DocumentationTask
from dbt_sugar.core.ui.traceback_manager import DbtSugarTracebackManager
from dbt_sugar.core.utils import check_and_compare_version
Expand All @@ -26,7 +27,7 @@ def check_and_print_version() -> str:
Returns:
str: version info message ready for printing
"""
needs_update, latest_version = check_and_compare_version()
_, latest_version = check_and_compare_version()
installed_version_message = f"Installed dbt-sugar version: {__version__}".rjust(40)
latest_version_message = f"Latest dbt-sugar version: {latest_version}".rjust(40)
if latest_version:
Expand Down Expand Up @@ -72,7 +73,7 @@ def check_and_print_version() -> str:
# Task-specific argument sub parsers
sub_parsers = parser.add_subparsers(title="Available dbt-sugar commands", dest="command")

# document task parser
# DOC task parser
document_sub_parser = sub_parsers.add_parser(
"doc", parents=[base_subparser], help="Runs documentation and test enforement task."
)
Expand Down Expand Up @@ -101,7 +102,6 @@ def check_and_print_version() -> str:
type=str,
default=str(),
)
# document_sub_parser.add_argument(

document_sub_parser.add_argument(
"--no-ask-tests",
Expand Down Expand Up @@ -143,7 +143,7 @@ def check_and_print_version() -> str:
default=False,
)

# document task parser
# ##### AUDIT Task
audit_sub_parser = sub_parsers.add_parser(
"audit", parents=[base_subparser], help="Runs audit task."
)
Expand All @@ -157,9 +157,17 @@ def check_and_print_version() -> str:
required=False,
)

# task handler

# ##### BOOTSTRAP Task Arg parser
bootstrap_sub_parser = sub_parsers.add_parser(
"bootstrap",
parents=[base_subparser],
help="Runs the bootstrap task, which creates model descriptor files for all your models.",
)
bootstrap_sub_parser.set_defaults(cls=BootstrapTask, which="bootstrap")


# task handler
def handle(
parser: argparse.ArgumentParser,
test_cli_args: List[str] = list(),
Expand Down Expand Up @@ -210,10 +218,22 @@ def handle(

if flag_parser.task == "audit":
audit_task: AuditTask = AuditTask(
flag_parser, dbt_project._project_dir, sugar_config=sugar_config
flag_parser,
dbt_project._project_dir,
sugar_config=sugar_config,
dbt_profile=dbt_profile,
)
return audit_task.run()

if flag_parser.task == "bootstrap":
bootstrap_task: BootstrapTask = BootstrapTask(
flags=flag_parser,
dbt_path=dbt_project._project_dir,
sugar_config=sugar_config,
dbt_profile=dbt_profile,
)
return bootstrap_task.run()

raise NotImplementedError(f"{flag_parser.task} is not supported.")


Expand Down
13 changes: 11 additions & 2 deletions dbt_sugar/core/task/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rich.console import Console
from rich.table import Table

from dbt_sugar.core.clients.dbt import DbtProfile
from dbt_sugar.core.clients.yaml_helpers import open_yaml
from dbt_sugar.core.config.config import DbtSugarConfig
from dbt_sugar.core.flags import FlagParser
Expand All @@ -22,9 +23,17 @@ class AuditTask(BaseTask):
Holds methods and attrs necessary to audit a model or a dbt project.
"""

def __init__(self, flags: FlagParser, dbt_path: Path, sugar_config: DbtSugarConfig) -> None:
def __init__(
self,
flags: FlagParser,
dbt_path: Path,
sugar_config: DbtSugarConfig,
dbt_profile: DbtProfile,
) -> None:
self.dbt_path = dbt_path
super().__init__(flags=flags, dbt_path=self.dbt_path, sugar_config=sugar_config)
super().__init__(
flags=flags, dbt_path=self.dbt_path, sugar_config=sugar_config, dbt_profile=dbt_profile
)
self.column_update_payload: Dict[str, Dict[str, Any]] = {}
self._flags = flags
self.model_name = self._flags.model
Expand Down
30 changes: 28 additions & 2 deletions dbt_sugar/core/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

from dbt_sugar.core.clients.dbt import DbtProfile
from dbt_sugar.core.clients.yaml_helpers import open_yaml, save_yaml
from dbt_sugar.core.config.config import DbtSugarConfig
from dbt_sugar.core.connectors.postgres_connector import PostgresConnector
from dbt_sugar.core.connectors.snowflake_connector import SnowflakeConnector
from dbt_sugar.core.flags import FlagParser
from dbt_sugar.core.logger import GLOBAL_LOGGER as logger

Expand All @@ -16,13 +19,26 @@
DEFAULT_EXCLUDED_YML_FILES = r"dbt_project.yml|packages.yml"


DB_CONNECTORS = {
"postgres": PostgresConnector,
"snowflake": SnowflakeConnector,
}


class BaseTask(abc.ABC):
"""Sets up basic API for task-like classes."""

def __init__(self, flags: FlagParser, dbt_path: Path, sugar_config: DbtSugarConfig) -> None:
def __init__(
self,
flags: FlagParser,
dbt_path: Path,
sugar_config: DbtSugarConfig,
dbt_profile: DbtProfile,
) -> None:
self.repository_path = dbt_path
self._sugar_config = sugar_config
self._flags = flags
self._dbt_profile = dbt_profile

# populated by class methods
self._excluded_folders_from_search_pattern: str = self.setup_paths_exclusion()
Expand All @@ -31,6 +47,16 @@ def __init__(self, flags: FlagParser, dbt_path: Path, sugar_config: DbtSugarConf
self.dbt_tests: Dict[str, List[Dict[str, Any]]] = {}
self.build_descriptions_dictionary()

def get_connector(self) -> Union[PostgresConnector, SnowflakeConnector]:
dbt_credentials = self._dbt_profile.profile
connector = DB_CONNECTORS.get(dbt_credentials.get("type", ""))
if not connector:
raise NotImplementedError(
f"Connector '{dbt_credentials.get('type')}' is not implemented."
)

return connector(dbt_credentials)

def setup_paths_exclusion(self) -> str:
"""Appends excluded_folders to the default folder exclusion patten."""
if self._sugar_config.dbt_project_info["excluded_folders"]:
Expand Down
112 changes: 112 additions & 0 deletions dbt_sugar/core/task/bootstrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Bootstrap module. Generates placeholders for all models in a dbt project."""


import functools
import operator
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Union

from dbt_sugar.core.clients.dbt import DbtProfile
from dbt_sugar.core.clients.yaml_helpers import open_yaml, save_yaml
from dbt_sugar.core.config.config import DbtSugarConfig
from dbt_sugar.core.flags import FlagParser
from dbt_sugar.core.task.doc import DocumentationTask


@dataclass
class DbtModelsDict:
"""Data class for dbt model info.
We make it a dataclass instead of a dict because the types
inside a dict would have been too messy and was upstetting type checkers.
"""

model_name: str
model_path: Path
model_columns: Sequence[str]


class BootstrapTask(DocumentationTask):
"""Sets up methods and orchestration of the bootstrap task.
The bootstrap task is a task that iterates through all the models
in a dbt project, checks the tables exist on the db, and generates
placeholder model descriptor files (schema.yml) for any column or models
that have not yet been documented.
"""

def __init__(
self,
flags: FlagParser,
dbt_path: Path,
sugar_config: DbtSugarConfig,
dbt_profile: DbtProfile,
) -> None:
# we specifically run the super init because we need to populate the cache
# of all dbt models, where they live etc
super().__init__(
flags=flags, dbt_profile=dbt_profile, config=sugar_config, dbt_path=dbt_path
)
self.dbt_models_dict: Dict[str, Union[Path, List[str]]] = {}
self._dbt_profile = dbt_profile
self.schema = self._dbt_profile.profile.get("target_schema", "")

self.dbt_models_data: List[DbtModelsDict] = []

def build_all_models_dict(self) -> None:
"""Walk through all .sql files and load their info (name, path etc) into a dict."""
_dbt_models_data = []
for root, _, files in os.walk(self.repository_path):
if not re.search(self._excluded_folders_from_search_pattern, root):
_dbt_models_data.append(
[
DbtModelsDict(
model_name=f.replace(".sql", ""),
model_path=Path(root, f),
model_columns=[],
)
for f in files
if f.lower().endswith(".sql")
and f.lower().replace(".sql", "")
not in self._sugar_config.dbt_project_info.get("excluded_models", [])
]
)
self.dbt_models_data = functools.reduce(operator.iconcat, _dbt_models_data, [])

def add_or_update_model_descriptor_placeholders(self, is_test: bool = False):
connector = self.get_connector()
for model_info in self.dbt_models_data:
model_descriptor_content = {}
model_info.model_columns = connector.get_columns_from_table(
model_info.model_name,
self.schema,
use_describe=self._sugar_config.dbt_project_info.get(
"use_describe_snowflake", False
),
)
(
model_descriptor_path,
descriptor_file_exists,
is_already_documented,
) = self.find_model_schema_file(model_name=model_info.model_name)
if descriptor_file_exists and model_descriptor_path:
model_descriptor_content = open_yaml(model_descriptor_path)

model_descriptor_content = self.create_or_update_model_entry(
is_already_documented,
model_descriptor_content,
model_name=model_info.model_name,
columns_sql=model_info.model_columns,
)
if is_test:
return self.order_schema_yml(model_descriptor_content)
if model_descriptor_path:
save_yaml(model_descriptor_path, self.order_schema_yml(model_descriptor_content))

def run(self) -> int:
self.build_all_models_dict()
self.add_or_update_model_descriptor_placeholders()
return 0
Loading

0 comments on commit 9d841e8

Please sign in to comment.