Skip to content

Commit

Permalink
Add support for validate_sql method to BigQuery (#819)
Browse files Browse the repository at this point in the history
In CLI contexts MetricFlow will issue dry run queries as part
of its warehouse validation operations, and so we are adding a
validate_sql method to all adapters.

This commit adds support for the validate_sql method to BigQuery. It
does so by creating a BigQuery-specific `dry_run` method on the
BigQueryConnectionManager. This simply passes through the input SQL
with the `dry_run` QueryJobParameter flag set True. This will result
in BigQuery computing and returning a cost estimate for the query,
or raising an exception in the event the query is not valid.

Note: constructing the response object involves some repetitive value
extraction from the QueryResult returned by BigQuery. While I would
ordinarily prefer to tidy this up first, we are pressed for time, and so
we postpone that cleanup in order to keep this change as isolated
as possible.
  • Loading branch information
tlento committed Jul 12, 2023
1 parent 3c34dbb commit d0d593e
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230712-014350.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add validate_sql to BigQuery adapter and dry_run to BigQueryConnectionManager
time: 2023-07-12T01:43:50.36167-04:00
custom:
Author: tlento
Issue: "805"
45 changes: 43 additions & 2 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,13 @@ def get_table_from_response(cls, resp):
column_names = [field.name for field in resp.schema]
return agate_helper.table_from_data_flat(resp, column_names)

def raw_execute(self, sql, use_legacy_sql=False, limit: Optional[int] = None):
def raw_execute(
self,
sql,
use_legacy_sql=False,
limit: Optional[int] = None,
dry_run: bool = False,
):
conn = self.get_thread_connection()
client = conn.handle

Expand All @@ -446,7 +452,11 @@ def raw_execute(self, sql, use_legacy_sql=False, limit: Optional[int] = None):
if active_user:
labels["dbt_invocation_id"] = active_user.invocation_id

job_params = {"use_legacy_sql": use_legacy_sql, "labels": labels}
job_params = {
"use_legacy_sql": use_legacy_sql,
"labels": labels,
"dry_run": dry_run,
}

priority = conn.credentials.priority
if priority == Priority.Batch:
Expand Down Expand Up @@ -554,6 +564,37 @@ def execute(

return response, table

def dry_run(self, sql: str) -> BigQueryAdapterResponse:
    """Validate a SQL statement by submitting it as a BigQuery dry-run job.

    The statement is passed through ``_add_query_comment`` and then handed to
    ``raw_execute`` with ``dry_run=True``. Selected attributes of the returned
    query job (cost estimates, location, project, job id) are copied into a
    ``BigQueryAdapterResponse``. Invalid SQL will result in an exception.

    :param str sql: The SQL statement to validate.
    :return: A response whose ``code`` is ``"DRY RUN"``.
    """
    commented_sql = self._add_query_comment(sql)
    # raw_execute returns a pair; only the query job is needed here.
    job, _ = self.raw_execute(commented_sql, dry_run=True)

    # TODO: Factor this repetitive block out into a factory method on
    # BigQueryAdapterResponse
    return BigQueryAdapterResponse(
        _message=f"Ran dry run query for statement of type {job.statement_type}",
        code="DRY RUN",
        bytes_billed=job.total_bytes_billed,
        bytes_processed=self.format_bytes(job.total_bytes_processed),
        location=job.location,
        project_id=job.project,
        job_id=job.job_id,
        slot_ms=job.slot_millis,
    )

@staticmethod
def _bq_job_link(location, project_id, job_id) -> str:
    """Build the Google Cloud console URL for a BigQuery job's query results page.

    :param location: Job location, used in the ``bq:<location>:<job_id>`` reference.
    :param project_id: Project placed in the ``project`` query parameter.
    :param job_id: Job identifier, used in the job reference.
    :return: The assembled console URL.
    """
    job_ref = f"bq:{location}:{job_id}"
    return (
        "https://console.cloud.google.com/bigquery"
        f"?project={project_id}&j={job_ref}&page=queryresults"
    )
Expand Down
10 changes: 10 additions & 0 deletions dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading
from typing import Dict, List, Optional, Any, Set, Union, Type

from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.nodes import ColumnLevelConstraint, ModelLevelConstraint, ConstraintType # type: ignore
from dbt.dataclass_schema import dbtClassMixin, ValidationError

Expand Down Expand Up @@ -1024,3 +1025,12 @@ def render_model_constraint(cls, constraint: ModelLevelConstraint) -> Optional[s
def debug_query(self):
    """Override for DebugTask method.

    Issues a trivial single-row query through ``self.execute``; the result
    is discarded.
    """
    probe_sql = "select 1 as id"
    self.execute(probe_sql)

def validate_sql(self, sql: str) -> AdapterResponse:
    """Ask the engine to validate the given SQL without executing it.

    Delegates to the connection manager's ``dry_run`` method, which submits
    the query with the `dry_run` flag set True.

    :param str sql: The sql to validate
    :return: Whatever ``self.connections.dry_run`` returns for ``sql``.
    """
    connection_manager = self.connections
    response = connection_manager.dry_run(sql)
    return response
27 changes: 27 additions & 0 deletions tests/functional/adapter/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import random

import pytest
from google.api_core.exceptions import NotFound

from dbt.tests.adapter.utils.test_array_append import BaseArrayAppend
from dbt.tests.adapter.utils.test_array_concat import BaseArrayConcat
Expand All @@ -24,6 +27,7 @@
from dbt.tests.adapter.utils.test_safe_cast import BaseSafeCast
from dbt.tests.adapter.utils.test_split_part import BaseSplitPart
from dbt.tests.adapter.utils.test_string_literal import BaseStringLiteral
from dbt.tests.adapter.utils.test_validate_sql import BaseValidateSqlMethod
from tests.functional.adapter.utils.fixture_array_append import (
models__array_append_actual_sql,
models__array_append_expected_sql,
Expand Down Expand Up @@ -167,3 +171,26 @@ class TestSplitPart(BaseSplitPart):

# Inherits everything from BaseStringLiteral; nothing is added or overridden here.
class TestStringLiteral(BaseStringLiteral):
pass


# Inherits everything from BaseValidateSqlMethod; nothing is added or overridden here.
class TestValidateSqlMethod(BaseValidateSqlMethod):
pass


class TestDryRunMethod:
    """Exercise the connection manager's ``dry_run`` method."""

    def test_dry_run_method(self, project) -> None:
        """Dry-run a DDL statement and check that the table was not created.

        Using CREATE TABLE lets the test demonstrate that no SQL is executed:
        looking the table up afterwards is expected to raise ``NotFound``.
        """
        adapter = project.adapter
        with adapter.connection_named("_test"):
            bq_client = adapter.connections.get_thread_connection().handle

            # Ten random decimal digits keep the table name unique per run.
            digits = random.choices([str(d) for d in range(10)], k=10)
            table_name = "test_dry_run_" + "".join(digits)
            table_id = f"{project.database}.{project.test_schema}.{table_name}"

            response = adapter.connections.dry_run(f"CREATE TABLE {table_id} (x INT64)")
            assert response.code == "DRY RUN"

            with pytest.raises(NotFound):
                bq_client.get_table(table_id)

0 comments on commit d0d593e

Please sign in to comment.