diff --git a/.changes/unreleased/Under the Hood-20240309-141054.yaml b/.changes/unreleased/Under the Hood-20240309-141054.yaml new file mode 100644 index 00000000000..4dff658a8c1 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240309-141054.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Improve dbt CLI speed +time: 2024-03-09T14:10:54.549618-05:00 +custom: + Author: dwreeves + Issue: "4627" diff --git a/core/dbt/artifacts/schemas/run/v5/run.py b/core/dbt/artifacts/schemas/run/v5/run.py index eb731b71b5d..e8b5d1ddf36 100644 --- a/core/dbt/artifacts/schemas/run/v5/run.py +++ b/core/dbt/artifacts/schemas/run/v5/run.py @@ -1,6 +1,5 @@ import threading -from typing import Any, Optional, Iterable, Tuple, Sequence, Dict -import agate +from typing import Any, Optional, Iterable, Tuple, Sequence, Dict, TYPE_CHECKING from dataclasses import dataclass, field from datetime import datetime @@ -22,9 +21,13 @@ from dbt_common.clients.system import write_json +if TYPE_CHECKING: + import agate + + @dataclass class RunResult(NodeResult): - agate_table: Optional[agate.Table] = field( + agate_table: Optional["agate.Table"] = field( default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None} ) diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 5dbeae13697..d1b9eebbf3d 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -20,24 +20,6 @@ from dbt.artifacts.schemas.catalog import CatalogArtifact from dbt.artifacts.schemas.run import RunExecutionResult from dbt_common.events.base_types import EventMsg -from dbt.task.build import BuildTask -from dbt.task.clean import CleanTask -from dbt.task.clone import CloneTask -from dbt.task.compile import CompileTask -from dbt.task.debug import DebugTask -from dbt.task.deps import DepsTask -from dbt.task.docs.generate import GenerateTask -from dbt.task.docs.serve import ServeTask -from dbt.task.freshness import FreshnessTask -from dbt.task.init import InitTask -from dbt.task.list import ListTask -from dbt.task.retry import RetryTask -from dbt.task.run import RunTask -from dbt.task.run_operation import RunOperationTask -from dbt.task.seed import SeedTask -from dbt.task.show import ShowTask -from dbt.task.snapshot import SnapshotTask -from dbt.task.test import TestTask @dataclass @@ -211,6 +193,8 @@ def cli(ctx, **kwargs): @requires.manifest def build(ctx, **kwargs): """Run all seeds, models, snapshots, and tests in DAG order""" + from dbt.task.build import BuildTask + task = BuildTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -239,6 +223,8 @@ def build(ctx, **kwargs): @requires.project def clean(ctx, **kwargs): """Delete all folders in the clean-targets list (usually the dbt_packages and target directories.)""" + from dbt.task.clean import CleanTask + task = CleanTask(ctx.obj["flags"], ctx.obj["project"]) results = task.run() @@ -279,6 +265,8 @@ def docs(ctx, **kwargs): @requires.manifest(write=False) def docs_generate(ctx, **kwargs): """Generate the documentation website for your project""" + from dbt.task.docs.generate import GenerateTask + task = GenerateTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -309,6 +297,8 @@ def docs_generate(ctx, **kwargs): @requires.runtime_config def docs_serve(ctx, **kwargs): """Serve the documentation website for your project""" + from dbt.task.docs.serve import ServeTask + task = ServeTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -348,6 +338,8 @@ def docs_serve(ctx, **kwargs): def compile(ctx, **kwargs): """Generates executable SQL from source, model, test, and analysis files. Compiled SQL files are written to the target/ directory.""" + from dbt.task.compile import CompileTask + task = CompileTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -387,6 +379,8 @@ def compile(ctx, **kwargs): def show(ctx, **kwargs): """Generates executable SQL for a named resource or inline query, runs that SQL, and returns a preview of the results. Does not materialize anything to the warehouse.""" + from dbt.task.show import ShowTask + task = ShowTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -413,6 +407,7 @@ def show(ctx, **kwargs): @requires.preflight def debug(ctx, **kwargs): """Show information on the current dbt environment and check dependencies, then test the database connection. Not to be confused with the --debug option which increases verbosity.""" + from dbt.task.debug import DebugTask task = DebugTask( ctx.obj["flags"], @@ -452,6 +447,8 @@ def deps(ctx, **kwargs): There is a way to add new packages by providing an `--add-package` flag to deps command which will allow user to specify a package they want to add in the format of packagename@version. """ + from dbt.task.deps import DepsTask + flags = ctx.obj["flags"] if flags.ADD_PACKAGE: if not flags.ADD_PACKAGE["version"] and flags.SOURCE != "local": @@ -481,6 +478,8 @@ def deps(ctx, **kwargs): @requires.preflight def init(ctx, **kwargs): """Initialize a new dbt project.""" + from dbt.task.init import InitTask + task = InitTask(ctx.obj["flags"], None) results = task.run() @@ -514,6 +513,8 @@ def init(ctx, **kwargs): @requires.manifest def list(ctx, **kwargs): """List the resources in your project""" + from dbt.task.list import ListTask + task = ListTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -578,6 +579,8 @@ def parse(ctx, **kwargs): @requires.manifest def run(ctx, **kwargs): """Compile SQL and execute against the current target database.""" + from dbt.task.run import RunTask + task = RunTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -608,6 +611,8 @@ def run(ctx, **kwargs): @requires.runtime_config def retry(ctx, **kwargs): """Retry the nodes that failed in the previous run.""" + from dbt.task.retry import RetryTask + # Retry will parse manifest inside the task after we consolidate the flags task = RetryTask( ctx.obj["flags"], @@ -644,6 +649,8 @@ def retry(ctx, **kwargs): @requires.postflight def clone(ctx, **kwargs): """Create clones of selected nodes based on their location in the manifest provided to --state.""" + from dbt.task.clone import CloneTask + task = CloneTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -676,6 +683,8 @@ def clone(ctx, **kwargs): @requires.manifest def run_operation(ctx, **kwargs): """Run the named macro with any supplied arguments.""" + from dbt.task.run_operation import RunOperationTask + task = RunOperationTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -711,6 +720,8 @@ def run_operation(ctx, **kwargs): @requires.manifest def seed(ctx, **kwargs): """Load data from csv files into your data warehouse.""" + from dbt.task.seed import SeedTask + task = SeedTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -743,6 +754,8 @@ def seed(ctx, **kwargs): @requires.manifest def snapshot(ctx, **kwargs): """Execute snapshots defined in your project""" + from dbt.task.snapshot import SnapshotTask + task = SnapshotTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -785,6 +798,8 @@ def source(ctx, **kwargs): @requires.manifest def freshness(ctx, **kwargs): """check the current freshness of the project's sources""" + from dbt.task.freshness import FreshnessTask + task = FreshnessTask( ctx.obj["flags"], ctx.obj["runtime_config"], @@ -825,6 +840,8 @@ def freshness(ctx, **kwargs): @requires.manifest def test(ctx, **kwargs): """Runs tests on data in deployed models. Run this after `dbt run`""" + from dbt.task.test import TestTask + task = TestTask( ctx.obj["flags"], ctx.obj["runtime_config"], diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index e3520bc780a..90f0a508c8c 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -13,6 +13,7 @@ Iterable, Mapping, Tuple, + TYPE_CHECKING, ) from typing_extensions import Protocol @@ -22,7 +23,6 @@ from dbt_common.clients.jinja import MacroProtocol from dbt_common.context import get_invocation_context from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names -from dbt_common.clients import agate_helper from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator from dbt.config import RuntimeConfig, Project from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER @@ -82,7 +82,8 @@ from dbt_common.utils import merge, AttrDict, cast_to_str from dbt import selected_resources -import agate +if TYPE_CHECKING: + import agate _MISSING = object() @@ -851,8 +852,10 @@ def load_result(self, name: str) -> Optional[AttrDict]: @contextmember() def store_result( - self, name: str, response: Any, agate_table: Optional[agate.Table] = None + self, name: str, response: Any, agate_table: Optional["agate.Table"] = None ) -> str: + from dbt_common.clients import agate_helper + if agate_table is None: agate_table = agate_helper.empty_table() @@ -872,7 +875,7 @@ def store_raw_result( message=Optional[str], code=Optional[str], rows_affected=Optional[str], - agate_table: Optional[agate.Table] = None, + agate_table: Optional["agate.Table"] = None, ) -> str: response = AdapterResponse(_message=message, code=code, rows_affected=rows_affected) return self.store_result(name, response, agate_table) @@ -921,7 +924,9 @@ def try_or_compiler_error( raise CompilationError(message_if_exception, self.model) @contextmember() - def load_agate_table(self) -> agate.Table: + def load_agate_table(self) -> "agate.Table": + from dbt_common.clients import agate_helper + if not isinstance(self.model, SeedNode): raise LoadAgateTableNotSeedError(self.model.resource_type, node=self.model) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 25e4c055f14..8b10a57a7f2 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -1,8 +1,7 @@ import json import re import io -import agate -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union, TYPE_CHECKING from dbt_common.exceptions import ( DbtRuntimeError, @@ -19,6 +18,10 @@ from dbt_common.dataclass_schema import ValidationError +if TYPE_CHECKING: + import agate + + class ContractBreakingChangeError(DbtRuntimeError): CODE = 10016 MESSAGE = "Breaking Change to Contract" @@ -1349,7 +1352,7 @@ def __init__(self, yaml_columns, sql_columns): self.sql_columns = sql_columns super().__init__(msg=self.get_message()) - def get_mismatches(self) -> agate.Table: + def get_mismatches(self) -> "agate.Table": # avoid a circular import from dbt_common.clients.agate_helper import table_from_data_flat @@ -1400,7 +1403,7 @@ def get_message(self) -> str: "This model has an enforced contract, and its 'columns' specification is missing" ) - table: agate.Table = self.get_mismatches() + table: "agate.Table" = self.get_mismatches() # Hack to get Agate table output as string output = io.StringIO() table.print_table(output=output, max_rows=None, max_column_width=50) # type: ignore diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py index 7b217853ce7..1c6c5002e27 100644 --- a/core/dbt/task/run_operation.py +++ b/core/dbt/task/run_operation.py @@ -2,8 +2,7 @@ import threading import traceback from datetime import datetime - -import agate +from typing import TYPE_CHECKING import dbt_common.exceptions from dbt.adapters.factory import get_adapter @@ -24,6 +23,10 @@ RESULT_FILE_NAME = "run_results.json" +if TYPE_CHECKING: + import agate + + class RunOperationTask(ConfiguredTask): def _get_macro_parts(self): macro_name = self.args.macro @@ -34,7 +37,7 @@ def _get_macro_parts(self): return package_name, macro_name - def _run_unsafe(self, package_name, macro_name) -> agate.Table: + def _run_unsafe(self, package_name, macro_name) -> "agate.Table": adapter = get_adapter(self.config) macro_kwargs = self.args.args diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index 0754ca2277f..d3b16b88bc0 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -1,4 +1,3 @@ -import agate import daff import io import json @@ -8,7 +7,7 @@ from dbt_common.events.format import pluralize from dbt_common.dataclass_schema import dbtClassMixin import threading -from typing import Dict, Any, Optional, Union, List +from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING from .compile import CompileRunner from .run import RunTask @@ -37,6 +36,10 @@ from dbt_common.ui import green, red +if TYPE_CHECKING: + import agate + + @dataclass class UnitTestDiff(dbtClassMixin): actual: List[Dict[str, Any]] @@ -325,7 +328,7 @@ def _get_unit_test_agate_table(self, result_table, actual_or_expected: str): return unit_test_table.select(columns) def _get_daff_diff( - self, expected: agate.Table, actual: agate.Table, ordered: bool = False + self, expected: "agate.Table", actual: "agate.Table", ordered: bool = False ) -> daff.TableDiff: expected_daff_table = daff.PythonTableView(list_rows_from_table(expected)) @@ -388,7 +391,7 @@ def get_runner_type(self, _): # This was originally in agate_helper, but that was moved out into dbt_common -def json_rows_from_table(table: agate.Table) -> List[Dict[str, Any]]: +def json_rows_from_table(table: "agate.Table") -> List[Dict[str, Any]]: "Convert a table to a list of row dict objects" output = io.StringIO() table.to_json(path=output) # type: ignore @@ -397,7 +400,7 @@ def json_rows_from_table(table: agate.Table) -> List[Dict[str, Any]]: # This was originally in agate_helper, but that was moved out into dbt_common -def list_rows_from_table(table: agate.Table) -> List[Any]: +def list_rows_from_table(table: "agate.Table") -> List[Any]: "Convert a table to a list of lists, where the first element represents the header" rows = [[col.name for col in table.columns]] for row in table.rows: