From 506f2c939a1599c6ce8ccad9d117c426118207a8 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 00:59:26 -0700 Subject: [PATCH 1/9] A very-WIP implementation of the PRQL parser --- core/dbt/contracts/files.py | 1 + core/dbt/contracts/graph/compiled.py | 7 ++ core/dbt/contracts/graph/parsed.py | 5 + core/dbt/graph/selector_spec.py | 2 +- core/dbt/node_types.py | 3 + core/dbt/parser/_dbt_prql.py | 115 +++++++++++++++++++++++ core/dbt/parser/base.py | 19 ++++ core/dbt/parser/language_provider.py | 55 +++++++++++ core/dbt/parser/models.py | 21 +++++ core/dbt/parser/read_files.py | 2 +- core/dbt/parser/schemas.py | 1 + test/unit/test_graph.py | 34 +++---- test/unit/test_graph_selector_methods.py | 31 ++++-- test/unit/test_node_types.py | 1 + test/unit/test_parser.py | 49 ++++++++++ 15 files changed, 319 insertions(+), 27 deletions(-) create mode 100644 core/dbt/parser/_dbt_prql.py create mode 100644 core/dbt/parser/language_provider.py diff --git a/core/dbt/contracts/files.py b/core/dbt/contracts/files.py index b915a0d1197..05b1cdb570b 100644 --- a/core/dbt/contracts/files.py +++ b/core/dbt/contracts/files.py @@ -196,6 +196,7 @@ class SourceFile(BaseSourceFile): docs: List[str] = field(default_factory=list) macros: List[str] = field(default_factory=list) env_vars: List[str] = field(default_factory=list) + language: str = "sql" @classmethod def big_seed(cls, path: FilePath) -> "SourceFile": diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index 118d104f537..e7b48ae9268 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -10,6 +10,7 @@ ParsedResource, ParsedRPCNode, ParsedSqlNode, + ParsedPrqlNode, ParsedGenericTestNode, ParsedSeedNode, ParsedSnapshotNode, @@ -92,6 +93,11 @@ class CompiledSqlNode(CompiledNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.SqlOperation]}) +@dataclass +class CompiledPrqlNode(CompiledNode): + resource_type: NodeType = field(metadata={"restrict": [NodeType.PrqlOperation]}) + + @dataclass class CompiledSeedNode(CompiledNode): # keep this in sync with ParsedSeedNode! @@ -146,6 +152,7 @@ def same_contents(self, other) -> bool: CompiledHookNode: ParsedHookNode, CompiledRPCNode: ParsedRPCNode, CompiledSqlNode: ParsedSqlNode, + CompiledPrqlNode: ParsedPrqlNode, CompiledSeedNode: ParsedSeedNode, CompiledSnapshotNode: ParsedSnapshotNode, CompiledSingularTestNode: ParsedSingularTestNode, diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 8fb6c6235aa..19d6eef5af6 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -363,6 +363,11 @@ class ParsedSqlNode(ParsedNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.SqlOperation]}) +@dataclass +class ParsedPrqlNode(ParsedNode): + resource_type: NodeType = field(metadata={"restrict": [NodeType.PrqlOperation]}) + + def same_seeds(first: ParsedNode, second: ParsedNode) -> bool: # for seeds, we check the hashes. If the hashes are different types, # no match. If the hashes are both the same 'path', log a warning and diff --git a/core/dbt/graph/selector_spec.py b/core/dbt/graph/selector_spec.py index 991ae7fcb89..5514ed95c81 100644 --- a/core/dbt/graph/selector_spec.py +++ b/core/dbt/graph/selector_spec.py @@ -80,7 +80,7 @@ def __post_init__(self): def default_method(cls, value: str) -> MethodName: if _probably_path(value): return MethodName.Path - elif value.lower().endswith((".sql", ".py", ".csv")): + elif value.lower().endswith((".sql", ".py", ".csv", ".prql")): return MethodName.File else: return MethodName.FQN diff --git a/core/dbt/node_types.py b/core/dbt/node_types.py index a6fa5ff4f84..eec59decc5c 100644 --- a/core/dbt/node_types.py +++ b/core/dbt/node_types.py @@ -13,6 +13,7 @@ class NodeType(StrEnum): # TODO: rm? RPCCall = "rpc" SqlOperation = "sql operation" + PrqlOperation = "prql operation" Documentation = "docs block" Source = "source" Macro = "macro" @@ -31,6 +32,7 @@ def executable(cls) -> List["NodeType"]: cls.Documentation, cls.RPCCall, cls.SqlOperation, + cls.PrqlOperation, ] @classmethod @@ -68,3 +70,4 @@ class RunHookType(StrEnum): class ModelLanguage(StrEnum): python = "python" sql = "sql" + prql = "prql" diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py new file mode 100644 index 00000000000..8e90372b045 --- /dev/null +++ b/core/dbt/parser/_dbt_prql.py @@ -0,0 +1,115 @@ +""" +This will be in the `dbt-prql` package, but including here during inital code review, so +we can test it without coordinating dependencies. +""" + +from __future__ import annotations + +import logging +import re +from collections import defaultdict + +import prql_python + +logger = logging.getLogger(__name__) + +word_regex = r"[\w\.\-_]+" +references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?" + +# dict of ref_type (e.g. source, ref) -> (dict of (package, table) -> literal) +references_type = dict[str, dict[tuple[str, str], str]] + + +def compile(prql: str, references: references_type): + """ + >>> print(compile( + ... "from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]", + ... references=dict( + ... source={('salesforce', 'in_process'): 'salesforce_schema.in_process_tbl'}, + ... ref={('foo', 'bar'): 'foo_schema.bar_tbl'} + ... ) + ... )) + SELECT + "{{ source('salesforce', 'in_process') }}".*, + "{{ ref('foo', 'bar') }}".*, + id + FROM + {{ source('salesforce', 'in_process') }} + JOIN {{ ref('foo', 'bar') }} USING(id) + """ + + # references = list_references(prql) + prql = _hack_sentinels_of_model(prql) + + sql = prql_python.to_sql(prql) + + # The intention was to replace the sentinels with table names. (Which would be done + # by dbt-prql). But we now just pass the SQL off and let dbt handle it; something + # expedient but not elegant. + # sql = replace_faux_jinja(sql, references) + return sql + + +def list_references(prql): + """ + List all references (e.g. sources / refs) in a given block. + + We need to decide: + + — What should prqlc return given `dbt source.foo.bar`, so dbt-prql can find the + references? + — Should it just fill in something that looks like jinja for expediancy? (We + don't support jinja though) + + >>> references = list_references("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)") + >>> dict(references) + {'source': [('salesforce', 'in_process')], 'ref': [('foo', 'bar')]} + """ + + out = defaultdict(list) + for t, package, model in _hack_references_of_prql_query(prql): + out[t] += [(package, model)] + + return out + + +def _hack_references_of_prql_query(prql) -> list[tuple[str, str, str]]: + """ + List the references in a prql query. + + This would be implemented by prqlc. + + >>> _hack_references_of_prql_query("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)") + [('source', 'salesforce', 'in_process'), ('ref', 'foo', 'bar')] + """ + + return re.findall(references_regex, prql) + + +def _hack_sentinels_of_model(prql: str) -> str: + """ + Replace the dbt calls with a jinja-like sentinel. + + This will be done by prqlc... + + >>> _hack_sentinels_of_model("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]") + "from (`{{ source('salesforce', 'in_process') }}`) | join (`{{ ref('foo', 'bar') }}`) [id]" + """ + return re.sub(references_regex, r"`{{ \1('\2', '\3') }}`", prql) + + +def replace_faux_jinja(sql: str, references: references_type): + """ + >>> print(replace_faux_jinja( + ... "SELECT * FROM {{ source('salesforce', 'in_process') }}", + ... references=dict(source={('salesforce', 'in_process'): 'salesforce_schema.in_process_tbl'}) + ... )) + SELECT * FROM salesforce_schema.in_process_tbl + + """ + for ref_type, lookups in references.items(): + for (package, table), literal in lookups.items(): + prql_new = sql.replace(f"{{{{ {ref_type}('{package}', '{table}') }}}}", literal) + sql = prql_new + + return sql diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index 2786a7c5744..b82ff1c3158 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -192,6 +192,8 @@ def _create_parsetime_node( name = block.name if block.path.relative_path.endswith(".py"): language = ModelLanguage.python + elif block.path.relative_path.endswith(".prql"): + language = ModelLanguage.prql else: # this is not ideal but we have a lot of tests to adjust if don't do it language = ModelLanguage.sql @@ -225,6 +227,7 @@ def _create_parsetime_node( path=path, original_file_path=block.path.original_file_path, raw_code=block.contents, + # language=language, ) raise ParsingException(msg, node=node) @@ -414,6 +417,22 @@ def parse_file(self, file_block: FileBlock) -> None: self.parse_node(file_block) +# TOOD: Currently `ModelParser` inherits from `SimpleSQLParser`, which inherits from +# SQLParser. So possibly `ModelParser` should instead be `SQLModelParser`, and in +# `ManifestLoader.load`, we should resolve which to use? (I think that's how it works; +# though then why all the inheritance and generics if every model is parsed with `ModelParser`?) +# class PRQLParser( +# ConfiguredParser[FileBlock, IntermediateNode, FinalNode], Generic[IntermediateNode, FinalNode] +# ): +# The full mro is: +# dbt.parser.models.ModelParser, +# dbt.parser.base.SimpleSQLParser, +# dbt.parser.base.SQLParser, +# dbt.parser.base.ConfiguredParser, +# dbt.parser.base.Parser, +# dbt.parser.base.BaseParser, + + class SimpleSQLParser(SQLParser[FinalNode, FinalNode]): def transform(self, node): return node diff --git a/core/dbt/parser/language_provider.py b/core/dbt/parser/language_provider.py new file mode 100644 index 00000000000..55b41622229 --- /dev/null +++ b/core/dbt/parser/language_provider.py @@ -0,0 +1,55 @@ +from __future__ import annotations + + +try: + import dbt_prql +except ImportError: + dbt_prql = None + +# dict of ref_type (e.g. source, ref) -> (dict of (package, table) -> literal) +references_type = dict[str, dict[tuple[str, str], str]] + + +class LanguageProvider: + """ + A LanguageProvider is a class that can parse a given language. + + Currently implemented only for PRQL, but we could extend this to other languages (in + the medium term) + + TODO: See notes in `ModelParser.render_update`; the current implementation has some + missmatches. + """ + + # def compile(self, code: str) -> ParsedNode: + def compile(self, code: str) -> str: + """ + Compile a given block into a ParsedNode. + """ + raise NotImplementedError("compile") + + def list_references(self, code: str) -> references_type: + """ + List all references (e.g. sources / refs) in a given block. + """ + raise NotImplementedError("list_references") + + +class PrqlProvider: + def __init__(self) -> None: + # TODO: Uncomment when dbt-prql is released + # if not dbt_prql: + # raise ImportError( + # "dbt_prql is required and not found; try running `pip install dbt_prql`" + # ) + pass + + def compile(self, code: str, references: dict) -> str: + from . import _dbt_prql as dbt_prql + + return dbt_prql.compile(code, references=references) + + def list_references(self, code: str) -> references_type: + from . import _dbt_prql as dbt_prql + + return dbt_prql.list_references(code) diff --git a/core/dbt/parser/models.py b/core/dbt/parser/models.py index 9a7b4974aaa..e7768e5cf10 100644 --- a/core/dbt/parser/models.py +++ b/core/dbt/parser/models.py @@ -31,6 +31,7 @@ from dbt.dataclass_schema import ValidationError from dbt.exceptions import ParsingException, validator_error_message, UndefinedMacroException +from .language_provider import PrqlProvider dbt_function_key_words = set(["ref", "source", "config", "get"]) dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"]) @@ -233,6 +234,26 @@ def render_update(self, node: ParsedModelNode, config: ContextConfig) -> None: raise ParsingException(msg, node=node) from exc return + if node.language == ModelLanguage.prql: + provider = PrqlProvider() + try: + context = self._context_for(node, config) + references = provider.list_references(node.raw_code) + sql = provider.compile(node.raw_code, references) + # This is currently kinda a hack; I'm not happy about it. + # + # The original intention was to use the result of `.compile` and pass + # that to the Model. But currently the parsing seems extremely coupled + # to the `ModelParser` object. So possibly we should instead inherit + # from that, and there should be one of those for each language? + node.raw_code = sql + node.language = ModelLanguage.sql + + except ValidationError as exc: + # we got a ValidationError - probably bad types in config() + msg = validator_error_message(exc) + raise ParsingException(msg, node=node) from exc + elif not flags.STATIC_PARSER: # jinja rendering super().render_update(node, config) diff --git a/core/dbt/parser/read_files.py b/core/dbt/parser/read_files.py index ccb6b1b0790..2297405265a 100644 --- a/core/dbt/parser/read_files.py +++ b/core/dbt/parser/read_files.py @@ -175,7 +175,7 @@ def read_files(project, files, parser_files, saved_files): project, files, project.model_paths, - [".sql", ".py"], + [".sql", ".py", ".prql"], ParseFileType.Model, saved_files, dbt_ignore_spec, diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 4a6202ff0b8..c8010de7cb4 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -272,6 +272,7 @@ def get_hashable_md(data: Union[str, int, float, List, Dict]) -> Union[str, List path=path, original_file_path=target.original_file_path, raw_code=raw_code, + # language="sql", ) raise ParsingException(msg, node=node) from exc diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 5534fe21f19..6efb14ceff1 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -140,7 +140,7 @@ def get_compiler(self, project): return dbt.compilation.Compiler(project) def use_models(self, models): - for k, v in models.items(): + for k, (source, lang) in models.items(): path = FilePath( searched_path='models', project_root=os.path.normcase(os.getcwd()), @@ -148,8 +148,8 @@ def use_models(self, models): modification_time=0.0, ) # FileHash can't be empty or 'search_key' will be None - source_file = SourceFile(path=path, checksum=FileHash.from_contents('abc')) - source_file.contents = v + source_file = SourceFile(path=path, checksum=FileHash.from_contents('abc'), language=lang) + source_file.contents = source self.mock_models.append(source_file) def load_manifest(self, config): @@ -162,7 +162,7 @@ def load_manifest(self, config): def test__single_model(self): self.use_models({ - 'model_one': 'select * from events', + 'model_one':( 'select * from events', 'sql'), }) config = self.get_config() @@ -181,8 +181,8 @@ def test__single_model(self): def test__two_models_simple_ref(self): self.use_models({ - 'model_one': 'select * from events', - 'model_two': "select * from {{ref('model_one')}}", + 'model_one':( 'select * from events', 'sql'), + 'model_two':( "select * from {{ref('model_one')}}", 'sql'), }) config = self.get_config() @@ -205,10 +205,10 @@ def test__two_models_simple_ref(self): def test__model_materializations(self): self.use_models({ - 'model_one': 'select * from events', - 'model_two': "select * from {{ref('model_one')}}", - 'model_three': "select * from events", - 'model_four': "select * from events", + 'model_one':( 'select * from events', 'sql'), + 'model_two':( "select * from {{ref('model_one')}}", 'sql'), + 'model_three':( 'select * from events', 'sql'), + 'model_four':( 'select * from events', 'sql'), }) cfg = { @@ -241,7 +241,7 @@ def test__model_materializations(self): def test__model_incremental(self): self.use_models({ - 'model_one': 'select * from events' + 'model_one':( 'select * from events', 'sql'), }) cfg = { @@ -269,15 +269,15 @@ def test__model_incremental(self): def test__dependency_list(self): self.use_models({ - 'model_1': 'select * from events', - 'model_2': 'select * from {{ ref("model_1") }}', - 'model_3': ''' + 'model_1':( 'select * from events', 'sql'), + 'model_2':( 'select * from {{ ref("model_1") }}', 'sql'), + 'model_3': (''' select * from {{ ref("model_1") }} union all select * from {{ ref("model_2") }} - ''', - 'model_4': 'select * from {{ ref("model_3") }}' - }) + ''', "sql"), + 'model_4':( 'select * from {{ ref("model_3") }}', 'sql'), +}) config = self.get_config() manifest = self.load_manifest(config) diff --git a/test/unit/test_graph_selector_methods.py b/test/unit/test_graph_selector_methods.py index e32267e2d6f..cf817ca404f 100644 --- a/test/unit/test_graph_selector_methods.py +++ b/test/unit/test_graph_selector_methods.py @@ -46,7 +46,8 @@ from .utils import replace_config -def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, alias=None, config_kwargs=None, fqn_extras=None, depends_on_macros=None): +# TODO: possibly change `sql` arg to `code` +def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, alias=None, config_kwargs=None, fqn_extras=None, depends_on_macros=None, language='sql'): if refs is None: refs = [] if sources is None: @@ -54,7 +55,7 @@ def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, al if tags is None: tags = [] if path is None: - path = f'{name}.sql' + path = f'{name}.{language}' if alias is None: alias = name if config_kwargs is None: @@ -78,7 +79,7 @@ def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, al depends_on_nodes.append(src.unique_id) return ParsedModelNode( - language='sql', + language=language, raw_code=sql, database='dbt', schema='dbt_schema', @@ -486,6 +487,18 @@ def table_model(ephemeral_model): path='subdirectory/table_model.sql' ) +@pytest.fixture +def table_model_prql(seed): + return make_model( + 'pkg', + 'table_model_prql', + 'from (dbt source employees)', + config_kwargs={'materialized': 'table'}, + refs=[seed], + tags=[], + path='subdirectory/table_model.prql' + ) + @pytest.fixture def table_model_py(seed): return make_model( @@ -627,11 +640,11 @@ def namespaced_union_model(seed, ext_source): ) @pytest.fixture -def manifest(seed, source, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, ext_source, ext_model, union_model, ext_source_2, +def manifest(seed, source, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, table_model_prql, ext_source, ext_model, union_model, ext_source_2, ext_source_other, ext_source_other_2, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, view_test_nothing, namespaced_seed, namespace_model, namespaced_union_model, macro_test_unique, macro_default_test_unique, macro_test_not_null, macro_default_test_not_null): - nodes = [seed, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, union_model, ext_model, + nodes = [seed, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, table_model_prql, union_model, ext_model, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, view_test_nothing, namespaced_seed, namespace_model, namespaced_union_model] sources = [source, ext_source, ext_source_2, @@ -669,7 +682,7 @@ def test_select_fqn(manifest): assert not search_manifest_using_method(manifest, method, 'ext.unions') # sources don't show up, because selection pretends they have no FQN. Should it? assert search_manifest_using_method(manifest, method, 'pkg') == { - 'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'view_model', 'ephemeral_model', 'seed', + 'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'view_model', 'ephemeral_model', 'seed', 'mynamespace.union_model', 'mynamespace.ephemeral_model', 'mynamespace.seed'} assert search_manifest_using_method( manifest, method, 'ext') == {'ext_model'} @@ -752,6 +765,8 @@ def test_select_file(manifest): manifest, method, 'table_model.py') == {'table_model_py'} assert search_manifest_using_method( manifest, method, 'table_model.csv') == {'table_model_csv'} + assert search_manifest_using_method( + manifest, method, 'table_model.prql') == {'table_model_prql'} assert search_manifest_using_method( manifest, method, 'union_model.sql') == {'union_model', 'mynamespace.union_model'} assert not search_manifest_using_method( @@ -766,7 +781,7 @@ def test_select_package(manifest): assert isinstance(method, PackageSelectorMethod) assert method.arguments == [] - assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'view_model', 'ephemeral_model', + assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'view_model', 'ephemeral_model', 'seed', 'raw.seed', 'unique_table_model_id', 'not_null_table_model_id', 'unique_view_model_id', 'view_test_nothing', 'mynamespace.seed', 'mynamespace.ephemeral_model', 'mynamespace.union_model', } @@ -785,7 +800,7 @@ def test_select_config_materialized(manifest): assert search_manifest_using_method(manifest, method, 'view') == { 'view_model', 'ext_model'} assert search_manifest_using_method(manifest, method, 'table') == { - 'table_model', 'table_model_py', 'table_model_csv', 'union_model', 'mynamespace.union_model'} + 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'union_model', 'mynamespace.union_model'} def test_select_config_meta(manifest): methods = MethodManager(manifest, None) diff --git a/test/unit/test_node_types.py b/test/unit/test_node_types.py index fcfb115b9b9..3ba3f7a903e 100644 --- a/test/unit/test_node_types.py +++ b/test/unit/test_node_types.py @@ -10,6 +10,7 @@ NodeType.Seed: "seeds", NodeType.RPCCall: "rpcs", NodeType.SqlOperation: "sql operations", + NodeType.PrqlOperation: "prql operations", NodeType.Documentation: "docs blocks", NodeType.Source: "sources", NodeType.Macro: "macros", diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index e88bfbc845b..fdc684c7944 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -40,6 +40,7 @@ import itertools from .utils import config_from_parts_or_dicts, normalize, generate_name_macros, MockNode, MockSource, MockDocumentation +import dataclasses def get_abs_os_path(unix_path): return normalize(os.path.abspath(unix_path)) @@ -713,6 +714,54 @@ def test_parse_error(self): with self.assertRaises(CompilationException): self.parser.parse_file(block) + def test_parse_prql_file(self): + prql_code = """ + from (dbt source.salesforce.in_process) + join (dbt ref.foo.bar) [id] + filter salary > 100 + """ + block = self.file_block_for(prql_code, 'nested/prql_model.prql') + self.parser.manifest.files[block.file.file_id] = block.file + self.parser.parse_file(block) + self.assert_has_manifest_lengths(self.parser.manifest, nodes=1) + node = list(self.parser.manifest.nodes.values())[0] + compiled_sql = """ +SELECT + "{{ source('salesforce', 'in_process') }}".*, + "{{ ref('foo', 'bar') }}".*, + id +FROM + {{ source('salesforce', 'in_process') }} + JOIN {{ ref('foo', 'bar') }} USING(id) +WHERE + salary > 100 + """.strip() + expected = ParsedModelNode( + alias='prql_model', + name='prql_model', + database='test', + schema='analytics', + resource_type=NodeType.Model, + unique_id='model.snowplow.prql_model', + fqn=['snowplow', 'nested', 'prql_model'], + package_name='snowplow', + original_file_path=normalize('models/nested/prql_model.prql'), + root_path=get_abs_os_path('./dbt_packages/snowplow'), + config=NodeConfig(materialized='view'), + path=normalize('nested/prql_model.prql'), + language='sql', # It's compiled into SQL + raw_code=compiled_sql, + checksum=block.file.checksum, + unrendered_config={'packages': set()}, + config_call_dict={}, + refs=[["foo", "bar"], ["foo", "bar"]], + sources=[['salesforce', 'in_process']], + ) + assertEqualNodes(node, expected) + file_id = 'snowplow://' + normalize('models/nested/prql_model.prql') + self.assertIn(file_id, self.parser.manifest.files) + self.assertEqual(self.parser.manifest.files[file_id].nodes, ['model.snowplow.prql_model']) + def test_parse_ref_with_non_string(self): py_code = """ def model(dbt, session): From fa3f17200ff66e1f2d48a0b80bb271a15aa5729d Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 11:10:38 -0700 Subject: [PATCH 2/9] Add a mock return from prql_python --- core/dbt/parser/_dbt_prql.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py index 8e90372b045..2826a40f038 100644 --- a/core/dbt/parser/_dbt_prql.py +++ b/core/dbt/parser/_dbt_prql.py @@ -9,7 +9,27 @@ import re from collections import defaultdict -import prql_python +try: + import prql_python +except ModuleNotFoundError: + # Always return the same SQL, mocking the prqlc output, so we can test this without + # configuring dependencies. (Obv fix as we expand the tests, way before we merge.) + class prql_python: # type: ignore + @staticmethod + def to_sql(prql): + compiled_sql = """ +SELECT +"{{ source('salesforce', 'in_process') }}".*, +"{{ ref('foo', 'bar') }}".*, +id +FROM +{{ source('salesforce', 'in_process') }} +JOIN {{ ref('foo', 'bar') }} USING(id) +WHERE +salary > 100 + """.strip() + return compiled_sql + logger = logging.getLogger(__name__) From 5a8fd1e90d3820ecafe394d8c6577e0146b2d60e Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 12:58:53 -0700 Subject: [PATCH 3/9] Ignore types in the import hacks (tests still fail b/c typing_extensions is not installed) --- core/dbt/parser/_dbt_prql.py | 7 ++++--- core/dbt/parser/language_provider.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py index 2826a40f038..529e80a51bb 100644 --- a/core/dbt/parser/_dbt_prql.py +++ b/core/dbt/parser/_dbt_prql.py @@ -10,10 +10,11 @@ from collections import defaultdict try: - import prql_python + import prql_python # type: ignore except ModuleNotFoundError: - # Always return the same SQL, mocking the prqlc output, so we can test this without - # configuring dependencies. (Obv fix as we expand the tests, way before we merge.) + # Always return the same SQL, mocking the prqlc output for a single case which we + # currently use in tests, so we can test this without configuring dependencies. (Obv + # fix as we expand the tests, way before we merge.) class prql_python: # type: ignore @staticmethod def to_sql(prql): diff --git a/core/dbt/parser/language_provider.py b/core/dbt/parser/language_provider.py index 55b41622229..ad2cd7223e0 100644 --- a/core/dbt/parser/language_provider.py +++ b/core/dbt/parser/language_provider.py @@ -2,7 +2,7 @@ try: - import dbt_prql + import dbt_prql # type: ignore except ImportError: dbt_prql = None From ebff2ceb726f1978fed669209b250e68fb3ed13d Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 13:13:12 -0700 Subject: [PATCH 4/9] Revert to importing builtins from typing --- core/dbt/parser/_dbt_prql.py | 7 ++++--- core/dbt/parser/language_provider.py | 8 +++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py index 529e80a51bb..b914347a827 100644 --- a/core/dbt/parser/_dbt_prql.py +++ b/core/dbt/parser/_dbt_prql.py @@ -8,6 +8,10 @@ import logging import re from collections import defaultdict +import typing + +if typing.TYPE_CHECKING: + from dbt.parser.language_provider import references_type try: import prql_python # type: ignore @@ -37,9 +41,6 @@ def to_sql(prql): word_regex = r"[\w\.\-_]+" references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?" -# dict of ref_type (e.g. source, ref) -> (dict of (package, table) -> literal) -references_type = dict[str, dict[tuple[str, str], str]] - def compile(prql: str, references: references_type): """ diff --git a/core/dbt/parser/language_provider.py b/core/dbt/parser/language_provider.py index ad2cd7223e0..78ea8d26672 100644 --- a/core/dbt/parser/language_provider.py +++ b/core/dbt/parser/language_provider.py @@ -7,7 +7,13 @@ dbt_prql = None # dict of ref_type (e.g. source, ref) -> (dict of (package, table) -> literal) -references_type = dict[str, dict[tuple[str, str], str]] +# references_type = dict[str, dict[tuple[str, str], str]] + +# I can't get the above to work on CI; I had thought that it was fine with +# from __future__ import annotations, but it seems not. So, we'll just use Dict. +from typing import Dict, Tuple + +references_type = Dict[str, Dict[Tuple[str, str], str]] class LanguageProvider: From c9572c310642940238ad3bf3e4747b1fc1c81f71 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 16:53:10 -0700 Subject: [PATCH 5/9] Always use the mock method to align the snapshot tests --- core/dbt/parser/_dbt_prql.py | 23 +++++++++++------------ core/dbt/parser/language_provider.py | 8 ++++---- test/unit/test_parser.py | 12 ++++++------ 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py index b914347a827..65e4c2b3c9e 100644 --- a/core/dbt/parser/_dbt_prql.py +++ b/core/dbt/parser/_dbt_prql.py @@ -13,16 +13,15 @@ if typing.TYPE_CHECKING: from dbt.parser.language_provider import references_type -try: - import prql_python # type: ignore -except ModuleNotFoundError: - # Always return the same SQL, mocking the prqlc output for a single case which we - # currently use in tests, so we can test this without configuring dependencies. (Obv - # fix as we expand the tests, way before we merge.) - class prql_python: # type: ignore - @staticmethod - def to_sql(prql): - compiled_sql = """ +# import prql_python + +# Always return the same SQL, mocking the prqlc output for a single case which we +# currently use in tests, so we can test this without configuring dependencies. (Obv +# fix as we expand the tests, way before we merge.) +class prql_python: # type: ignore + @staticmethod + def to_sql(prql): + compiled_sql = """ SELECT "{{ source('salesforce', 'in_process') }}".*, "{{ ref('foo', 'bar') }}".*, @@ -32,8 +31,8 @@ def to_sql(prql): JOIN {{ ref('foo', 'bar') }} USING(id) WHERE salary > 100 - """.strip() - return compiled_sql + """.strip() + return compiled_sql logger = logging.getLogger(__name__) diff --git a/core/dbt/parser/language_provider.py b/core/dbt/parser/language_provider.py index 78ea8d26672..6b5849c9329 100644 --- a/core/dbt/parser/language_provider.py +++ b/core/dbt/parser/language_provider.py @@ -1,10 +1,10 @@ from __future__ import annotations -try: - import dbt_prql # type: ignore -except ImportError: - dbt_prql = None +# try: +# import dbt_prql # type: ignore +# except ImportError: +# dbt_prql = None # dict of ref_type (e.g. source, ref) -> (dict of (package, table) -> literal) # references_type = dict[str, dict[tuple[str, str], str]] diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index fdc684c7944..aceb74ee9cb 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -727,14 +727,14 @@ def test_parse_prql_file(self): node = list(self.parser.manifest.nodes.values())[0] compiled_sql = """ SELECT - "{{ source('salesforce', 'in_process') }}".*, - "{{ ref('foo', 'bar') }}".*, - id +"{{ source('salesforce', 'in_process') }}".*, +"{{ ref('foo', 'bar') }}".*, +id FROM - {{ source('salesforce', 'in_process') }} - JOIN {{ ref('foo', 'bar') }} USING(id) +{{ source('salesforce', 'in_process') }} +JOIN {{ ref('foo', 'bar') }} USING(id) WHERE - salary > 100 +salary > 100 """.strip() expected = ParsedModelNode( alias='prql_model', From 8eece383eab104dfde17f2b0d7a864913dafaa1a Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 17:27:57 -0700 Subject: [PATCH 6/9] flake --- core/dbt/parser/_dbt_prql.py | 1 + 1 file changed, 1 insertion(+) diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py index 65e4c2b3c9e..52c72cac390 100644 --- a/core/dbt/parser/_dbt_prql.py +++ b/core/dbt/parser/_dbt_prql.py @@ -15,6 +15,7 @@ # import prql_python + # Always return the same SQL, mocking the prqlc output for a single case which we # currently use in tests, so we can test this without configuring dependencies. (Obv # fix as we expand the tests, way before we merge.) From 86eb68f40d33929c872eb6bde6f001cd3b1f7488 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 19:08:10 -0700 Subject: [PATCH 7/9] Add test to test_graph.py --- core/dbt/parser/_dbt_prql.py | 46 +++++++++++++++++++++++++++++------- test/unit/test_graph.py | 24 ++++++++++++++++++- test/unit/test_parser.py | 8 +++---- 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py index 52c72cac390..6cd762f4778 100644 --- a/core/dbt/parser/_dbt_prql.py +++ b/core/dbt/parser/_dbt_prql.py @@ -13,16 +13,37 @@ if typing.TYPE_CHECKING: from dbt.parser.language_provider import references_type -# import prql_python - -# Always return the same SQL, mocking the prqlc output for a single case which we -# currently use in tests, so we can test this without configuring dependencies. (Obv -# fix as we expand the tests, way before we merge.) +# import prql_python +# This mocks the prqlc output for two cases which we currently use in tests, so we can +# test this without configuring dependencies. (Obv fix as we expand the tests, way +# before we merge.) class prql_python: # type: ignore @staticmethod - def to_sql(prql): - compiled_sql = """ + def to_sql(prql) -> str: + + query_1 = "from employees" + + query_1_compiled = """ +SELECT + employees.* +FROM + employees + """.strip() + + query_2 = """ +from (dbt source.salesforce.in_process) +join (dbt ref.foo.bar) [id] +filter salary > 100 + """.strip() + + query_2_refs_replaced = """ +from (`{{ source('salesforce', 'in_process') }}`) +join (`{{ ref('foo', 'bar') }}`) [id] +filter salary > 100 + """.strip() + + query_2_compiled = """ SELECT "{{ source('salesforce', 'in_process') }}".*, "{{ ref('foo', 'bar') }}".*, @@ -33,7 +54,16 @@ def to_sql(prql): WHERE salary > 100 """.strip() - return compiled_sql + + lookup = dict( + { + query_1: query_1_compiled, + query_2: query_2_compiled, + query_2_refs_replaced: query_2_compiled, + } + ) + + return lookup[prql] logger = logging.getLogger(__name__) diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 6efb14ceff1..959b1ceaa4e 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -60,6 +60,9 @@ def setUp(self): # Create file filesystem searcher self.filesystem_search = patch('dbt.parser.read_files.filesystem_search') def mock_filesystem_search(project, relative_dirs, extension, ignore_spec): + # Adding in `and "prql" not in extension` will cause a bunch of tests to + # fail; need to understand more on how these are constructed to debug. + # Possibly `sql not in extension` is a way of having it only run once. if 'sql' not in extension: return [] if 'models' not in relative_dirs: @@ -144,7 +147,7 @@ def use_models(self, models): path = FilePath( searched_path='models', project_root=os.path.normcase(os.getcwd()), - relative_path='{}.sql'.format(k), + relative_path=f'{k}.{lang}', modification_time=0.0, ) # FileHash can't be empty or 'search_key' will be None @@ -328,3 +331,22 @@ def test__partial_parse(self): manifest.metadata.dbt_version = '99999.99.99' is_partial_parsable, _ = loader.is_partial_parsable(manifest) self.assertFalse(is_partial_parsable) + + def test_models_prql(self): + self.use_models({ + 'model_prql':( 'from employees', 'prql'), + }) + + config = self.get_config() + manifest = self.load_manifest(config) + + compiler = self.get_compiler(config) + linker = compiler.compile(manifest) + + self.assertEqual( + list(linker.nodes()), + ['model.test_models_compile.model_prql']) + + self.assertEqual( + list(linker.edges()), + []) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index aceb74ee9cb..b159acca86d 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -716,10 +716,10 @@ def test_parse_error(self): def test_parse_prql_file(self): prql_code = """ - from (dbt source.salesforce.in_process) - join (dbt ref.foo.bar) [id] - filter salary > 100 - """ +from (dbt source.salesforce.in_process) +join (dbt ref.foo.bar) [id] +filter salary > 100 + """.strip() block = self.file_block_for(prql_code, 'nested/prql_model.prql') self.parser.manifest.files[block.file.file_id] = block.file self.parser.parse_file(block) From bc8b65095eb41d0df84150029c843a4204ea29f7 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 23:04:03 -0700 Subject: [PATCH 8/9] Add language on error nodes --- core/dbt/parser/base.py | 4 ++-- core/dbt/parser/schemas.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index b82ff1c3158..9b429aeb3c5 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -157,7 +157,7 @@ def _mangle_hooks(self, config): config[key] = [hooks.get_hook_dict(h) for h in config[key]] def _create_error_node( - self, name: str, path: str, original_file_path: str, raw_code: str, language: str = "sql" + self, name: str, path: str, original_file_path: str, raw_code: str, language: str ) -> UnparsedNode: """If we hit an error before we've actually parsed a node, provide some level of useful information by attaching this to the exception. @@ -227,7 +227,7 @@ def _create_parsetime_node( path=path, original_file_path=block.path.original_file_path, raw_code=block.contents, - # language=language, + language=language, ) raise ParsingException(msg, node=node) diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index c8010de7cb4..2e0b9186d44 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -272,7 +272,7 @@ def get_hashable_md(data: Union[str, int, float, List, Dict]) -> Union[str, List path=path, original_file_path=target.original_file_path, raw_code=raw_code, - # language="sql", + language="sql", ) raise ParsingException(msg, node=node) from exc From 472940423cb9817dc390ac7274eb03209b2576df Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Wed, 5 Oct 2022 18:35:06 -0700 Subject: [PATCH 9/9] Remove unused `PrqlNode` & friends --- core/dbt/contracts/graph/compiled.py | 7 ------- core/dbt/contracts/graph/parsed.py | 5 ----- core/dbt/node_types.py | 2 -- test/unit/test_node_types.py | 1 - 4 files changed, 15 deletions(-) diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index e7b48ae9268..118d104f537 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -10,7 +10,6 @@ ParsedResource, ParsedRPCNode, ParsedSqlNode, - ParsedPrqlNode, ParsedGenericTestNode, ParsedSeedNode, ParsedSnapshotNode, @@ -93,11 +92,6 @@ class CompiledSqlNode(CompiledNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.SqlOperation]}) -@dataclass -class CompiledPrqlNode(CompiledNode): - resource_type: NodeType = field(metadata={"restrict": [NodeType.PrqlOperation]}) - - @dataclass class CompiledSeedNode(CompiledNode): # keep this in sync with ParsedSeedNode! @@ -152,7 +146,6 @@ def same_contents(self, other) -> bool: CompiledHookNode: ParsedHookNode, CompiledRPCNode: ParsedRPCNode, CompiledSqlNode: ParsedSqlNode, - CompiledPrqlNode: ParsedPrqlNode, CompiledSeedNode: ParsedSeedNode, CompiledSnapshotNode: ParsedSnapshotNode, CompiledSingularTestNode: ParsedSingularTestNode, diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 19d6eef5af6..8fb6c6235aa 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -363,11 +363,6 @@ class ParsedSqlNode(ParsedNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.SqlOperation]}) -@dataclass -class ParsedPrqlNode(ParsedNode): - resource_type: NodeType = field(metadata={"restrict": [NodeType.PrqlOperation]}) - - def same_seeds(first: ParsedNode, second: ParsedNode) -> bool: # for seeds, we check the hashes. If the hashes are different types, # no match. If the hashes are both the same 'path', log a warning and diff --git a/core/dbt/node_types.py b/core/dbt/node_types.py index eec59decc5c..f90ccea327c 100644 --- a/core/dbt/node_types.py +++ b/core/dbt/node_types.py @@ -13,7 +13,6 @@ class NodeType(StrEnum): # TODO: rm? RPCCall = "rpc" SqlOperation = "sql operation" - PrqlOperation = "prql operation" Documentation = "docs block" Source = "source" Macro = "macro" @@ -32,7 +31,6 @@ def executable(cls) -> List["NodeType"]: cls.Documentation, cls.RPCCall, cls.SqlOperation, - cls.PrqlOperation, ] @classmethod diff --git a/test/unit/test_node_types.py b/test/unit/test_node_types.py index 3ba3f7a903e..fcfb115b9b9 100644 --- a/test/unit/test_node_types.py +++ b/test/unit/test_node_types.py @@ -10,7 +10,6 @@ NodeType.Seed: "seeds", NodeType.RPCCall: "rpcs", NodeType.SqlOperation: "sql operations", - NodeType.PrqlOperation: "prql operations", NodeType.Documentation: "docs blocks", NodeType.Source: "sources", NodeType.Macro: "macros",