Skip to content

Commit

Permalink
feat support another blocks and template expression
Browse files Browse the repository at this point in the history
  • Loading branch information
soundTricker committed Oct 25, 2024
1 parent 05625a2 commit ff7a79b
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 50 deletions.
56 changes: 32 additions & 24 deletions sqlfluff_templater_dataform/templater.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
import enum
import logging
import os
import os.path
import re
import uuid
from typing import (
Iterator,
List,
Optional,
)
from sqlfluff.core.templaters.base import RawTemplater, TemplatedFile, large_file_check, RawFileSlice, TemplatedFileSlice
Optional, Tuple, )

from sqlfluff.cli.formatters import OutputStreamFormatter
from sqlfluff.core import FluffConfig
from sqlfluff.core.errors import SQLFluffSkipFile

from sqlfluff.core.templaters.base import RawTemplater, TemplatedFile, large_file_check, RawFileSlice, \
TemplatedFileSlice

# Instantiate the templater logger
templater_logger = logging.getLogger("sqlfluff.templater")

class UsedJSBlockError(SQLFluffSkipFile):
    """Raised to skip linting a file because it uses a js block.

    This package does not support dataform js blocks.
    """

class DataformTemplater(RawTemplater):
"""A templater using dataform."""

name = "dataform"
sequential_fail_limit = 3
adapters = {}

def __init__(self, **kwargs):
self.sqlfluff_config = None
Expand Down Expand Up @@ -59,10 +54,6 @@ def process(
config: Optional["FluffConfig"] = None,
formatter: Optional["OutputStreamFormatter"] = None,
):
templater_logger.info(in_str)
if in_str and self.has_js_block(in_str):
raise UsedJSBlockError("JavaScript block is not supported.")

templated_sql, raw_slices, templated_slices = self.slice_sqlx_template(in_str)

return TemplatedFile(
Expand All @@ -73,12 +64,8 @@ def process(
raw_sliced=raw_slices,
), []

def has_js_block(self, sql: str) -> bool:
    """Report whether *sql* contains a ``js { ... }`` block.

    The pattern accepts optional whitespace between ``js`` and the opening
    brace and tolerates one level of nested braces inside the block.

    Args:
        sql: The SQLX source text to inspect.

    Returns:
        ``True`` when at least one js block is found, otherwise ``False``.
    """
    js_block_re = re.compile(r'js\s*\{(?:[^{}]|\{[^{}]*\})*\}', re.DOTALL)
    return js_block_re.search(sql) is not None

def replace_blocks(self, in_str: str) -> str:
pattern = re.compile(r'config\s*\{(?:[^{}]|\{[^{}]*\})*\}', re.DOTALL)
def replace_blocks(self, in_str: str, block_name: str) -> str:
    """Remove every ``<block_name> { ... }`` block from *in_str*.

    Args:
        in_str: The SQLX source text.
        block_name: Keyword that introduces the block (for example
            ``config``). It is spliced into the regular expression as-is,
            without escaping.

    Returns:
        *in_str* with each matching block replaced by an empty string.
        Text outside the blocks, including the newline that follows a
        block, is left untouched.
    """
    # The brace part of the pattern copes with one level of nested braces.
    block_regex = block_name + r'\s*\{(?:[^{}]|\{[^{}]*\})*\}'
    return re.compile(block_regex, re.DOTALL).sub('', in_str)

def replace_ref_with_bq_table(self, sql):
Expand All @@ -95,17 +82,38 @@ def ref_to_table(match):

return re.sub(pattern, ref_to_table, sql)

def extract_templates(self, sql):
    """Return every ``${...}`` template expression found in *sql*.

    Whitespace between ``$`` and ``{`` is accepted, expressions may span
    several lines, and one level of nested braces is tolerated.

    Args:
        sql: The SQLX source text to scan.

    Returns:
        A list with the full text of each match, in order of appearance.
    """
    template_re = re.compile(r'\$\s*\{(?:[^{}]|\{[^{}]*})*}', re.DOTALL)
    # The pattern has no capturing groups, so findall yields whole matches.
    return template_re.findall(sql)

def replace_templates(self, sql):
    """Swap each ``${...}`` expression in *sql* for a generated mask.

    The expressions are collected with ``self.extract_templates``. Every
    occurrence of an expression's text is replaced by a string built from a
    fresh ``uuid.uuid1()``: an ``a`` is put in front and each ``-`` becomes
    ``.a``, so every dot-separated segment of the mask starts with a letter.

    Args:
        sql: The SQLX source text.

    Returns:
        *sql* with the template expressions masked.
    """
    masked_sql = sql
    for template in self.extract_templates(sql):
        # https://github.com/sqlfluff/sqlfluff/issues/1540#issuecomment-1110835283
        placeholder = "a" + ".a".join(str(uuid.uuid1()).split("-"))
        masked_sql = masked_sql.replace(template, placeholder)
    return masked_sql

# SQLX をスライスして、RawFileSlice と TemplatedFileSlice を同時に返す関数
def slice_sqlx_template(self, sql: str) -> (str, List[RawFileSlice], List[TemplatedFileSlice]):
def slice_sqlx_template(self, sql: str) -> Tuple[str, List[RawFileSlice], List[TemplatedFileSlice]]:
# config や js ブロックを改行に置換
replaced_sql = self.replace_blocks(sql)
replaced_sql = sql
for block_name in ['config', 'js', 'pre_operations', 'post_operations']:
replaced_sql = self.replace_blocks(replaced_sql, block_name)

# ref 関数をBigQueryテーブル名に置換
replaced_sql = self.replace_ref_with_bq_table(replaced_sql)

# ${} を置換
replaced_sql = self.replace_templates(replaced_sql)

# SQLX の構造に対応する正規表現パターン
patterns = [
(r'config\s*\{(?:[^{}]|\{[^{}]*\})*\}', 'templated'), # config ブロック
# (r'js\s*\{(?:[^{}]|\{[^{}]*\})*\}', 'templated'), # js ブロック
(r'js\s*\{(?:[^{}]|\{[^{}]*\})*\}', 'templated'), # js ブロック
(r'pre_operations\s*\{(?:[^{}]|\{[^{}]*\})*\}', 'templated'), # pre_operations ブロック
(r'post_operations\s*\{(?:[^{}]|\{[^{}]*\})*\}', 'templated'), # post_operations ブロック
(r'\$\{\s*ref\(\s*\'([^\']+)\'(?:\s*,\s*\'([^\']+)\')?\s*\)\s*\}', 'templated') # ref 関数
]

Expand Down
140 changes: 114 additions & 26 deletions test/templater_test.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,5 @@
"""Tests for the dataform templater."""

def test_has_js_block(templater):
    """has_js_block is true only when the SQLX text contains a js block."""
    has_js_sql = """config {
type: "table",
columns: {
"test" : "test",
"value:: "value"
}
}
js {
const myVar = "test";
}
SELECT * FROM my_table"""
    not_has_js_sql = """config {
type: "table",
columns: {
"test" : "test",
"value:: "value"
}
}
SELECT * FROM my_table"""
    # Assert on truthiness directly instead of comparing with ``== True``.
    assert templater.has_js_block(has_js_sql)
    assert not templater.has_js_block(not_has_js_sql)
import re

def test_replace_ref_with_bq_table_single_ref(templater):
input_sql = "SELECT * FROM ${ref('test')}"
Expand All @@ -41,7 +19,24 @@ def test_replace_ref_with_bq_table_multiple_refs(templater):
result = templater.replace_ref_with_bq_table(input_sql)
assert result == expected_sql

def test_replace_blocks_single_block(templater):
def test_replace_templates_without_linebreak(templater):
    """Single-line ``${...}`` expressions are each replaced by a mask."""
    masked = templater.replace_templates(
        "SELECT * FROM ${hoge}, ${hoge.fuga('another')}"
    )
    mask_pattern = r"SELECT \* FROM a.+, a.+"
    assert re.match(mask_pattern, masked)

def test_replace_templates_with_linebreak(templater):
    """``${...}`` expressions that span several lines are masked as well."""
    multiline_sql = """SELECT * FROM ${
hoge
}, ${
hoge.fuga('another')
}"""
    masked = templater.replace_templates(multiline_sql)
    mask_pattern = r"""SELECT \* FROM a.+, a.+"""
    assert re.match(mask_pattern, masked)


def test_replace_blocks_single_block_config(templater):
input_sql = """config {
type: "table",
columns: {
Expand All @@ -52,13 +47,48 @@ def test_replace_blocks_single_block(templater):
SELECT * FROM my_table"""

expected_sql = "\nSELECT * FROM my_table"
result = templater.replace_blocks(input_sql)
result = templater.replace_blocks(input_sql, "config")
assert result == expected_sql

def test_replace_blocks_single_block_js(templater):
    """A js block is stripped, leaving the newline and the query behind."""
    sqlx = """js {
var hoge = "fuga"
}
SELECT * FROM my_table"""

    stripped = templater.replace_blocks(sqlx, "js")

    assert stripped == "\nSELECT * FROM my_table"

def test_replace_blocks_single_block_pre_operations(templater):
    """A pre_operations block is stripped, leaving the query behind."""
    sqlx = """pre_operations {
CREATE TEMP FUNCTION AddFourAndDivide(x INT64, y INT64)
RETURNS FLOAT64
AS ((x + 4) / y);
}
SELECT * FROM my_table"""

    stripped = templater.replace_blocks(sqlx, "pre_operations")

    assert stripped == "\nSELECT * FROM my_table"

def test_replace_blocks_single_block_post_operations(templater):
    """A post_operations block holding a ``${...}`` expression is stripped."""
    sqlx = """post_operations {
GRANT `roles/bigquery.dataViewer`
ON
TABLE ${self()}
TO "group:allusers@example.com", "user:otheruser@example.com"
}
SELECT * FROM my_table"""

    stripped = templater.replace_blocks(sqlx, "post_operations")

    assert stripped == "\nSELECT * FROM my_table"

def test_replace_blocks_no_block(templater):
    """Input without the requested block is returned unchanged."""
    input_sql = "SELECT * FROM my_table"
    expected_sql = "SELECT * FROM my_table"
    # replace_blocks needs the block name; a one-argument call would raise
    # TypeError, so only the two-argument form is used here.
    result = templater.replace_blocks(input_sql, "config")
    assert result == expected_sql

# Tests for slice_sqlx_template
Expand Down Expand Up @@ -116,6 +146,64 @@ def test_slice_sqlx_template_with_multiple_refs(templater):
assert templated_slices[4].slice_type == "templated"
assert templated_slices[5].slice_type == "literal"


def test_slice_sqlx_template_with_full_expression_query(templater):
    """Slice a SQLX file that uses all four block kinds, refs and ``${...}``.

    Checks the templated SQL, the start of each raw slice, and that the
    templated slices alternate between "templated" and "literal".
    """
    input_sqlx = """config {
type: "view",
columns: {
"test" : "test",
"value:: "value"
}
}
js {
var hoge = "fuga"
}
pre_operations {
CREATE TEMP FUNCTION AddFourAndDivide(x INT64, y INT64)
RETURNS FLOAT64
AS ((x + 4) / y);
}
post_operations {
GRANT `roles/bigquery.dataViewer`
ON
TABLE ${self()}
TO "group:allusers@example.com", "user:otheruser@example.com"
}
SELECT * FROM ${ref('test')} JOIN ${ref('other_table')} ON test.id = other_table.id AND test.name = ${hoge}
"""
    expected_sql = r"\s+SELECT \* FROM `my_project\.my_dataset\.test` JOIN `my_project\.my_dataset\.other_table` ON test\.id = other_table\.id AND test\.name = a.+\n"

    templated_sql, raw_slices, templated_slices = templater.slice_sqlx_template(input_sqlx)

    assert re.match(expected_sql, templated_sql)

    # Raw slices: each listed index must start with the given source text.
    assert len(raw_slices) == 12
    expected_prefixes = (
        (0, "config"),
        (2, "js"),
        (4, "pre_operations"),
        (6, "post_operations"),
        (7, "\n\nSELECT *"),
        (8, '${ref'),
        (9, ' JOIN'),
        (10, '${ref'),
    )
    for index, prefix in expected_prefixes:
        assert raw_slices[index].raw.startswith(prefix)
    assert re.match(r" ON test.id = other_table.id AND test.name = \$\{hoge}\n", raw_slices[11].raw)

    # Templated slices: even positions are "templated", odd are "literal".
    assert len(templated_slices) == 12
    for index, templated_slice in enumerate(templated_slices):
        expected_type = "templated" if index % 2 == 0 else "literal"
        assert templated_slice.slice_type == expected_type


def test_slice_sqlx_template_with_no_ref(templater):
input_sqlx = """SELECT * FROM my_table WHERE true
"""
Expand Down

0 comments on commit ff7a79b

Please sign in to comment.