Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A very-WIP implementation of the PRQL plugin, for discussion #5982

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class SourceFile(BaseSourceFile):
docs: List[str] = field(default_factory=list)
macros: List[str] = field(default_factory=list)
env_vars: List[str] = field(default_factory=list)
language: str = "sql"

@classmethod
def big_seed(cls, path: FilePath) -> "SourceFile":
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/graph/selector_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __post_init__(self):
def default_method(cls, value: str) -> MethodName:
if _probably_path(value):
return MethodName.Path
elif value.lower().endswith((".sql", ".py", ".csv")):
elif value.lower().endswith((".sql", ".py", ".csv", ".prql")):
return MethodName.File
else:
return MethodName.FQN
Expand Down
1 change: 1 addition & 0 deletions core/dbt/node_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,4 @@ class RunHookType(StrEnum):
class ModelLanguage(StrEnum):
python = "python"
sql = "sql"
prql = "prql"
167 changes: 167 additions & 0 deletions core/dbt/parser/_dbt_prql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""
This will be in the `dbt-prql` package, but including here during initial code review, so
we can test it without coordinating dependencies.
"""

from __future__ import annotations

import logging
import re
from collections import defaultdict
import typing

if typing.TYPE_CHECKING:
from dbt.parser.language_provider import references_type


# import prql_python
# This mocks the prqlc output for two cases which we currently use in tests, so we can
# test this without configuring dependencies. (Obv fix as we expand the tests, way
# before we merge.)
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
class prql_python:  # type: ignore
    """
    Stand-in for the real `prql_python` package.

    It recognises only the handful of PRQL queries used by the current tests and
    returns a canned SQL string for each of them.
    """

    @staticmethod
    def to_sql(prql) -> str:
        """
        Return the canned SQL for a known PRQL query.

        The query is looked up verbatim, so any input that is not one of the known
        queries raises `KeyError`.
        """

        # Query 1: a bare `from`.
        employees_sql = "\n".join(
            (
                "SELECT",
                "employees.*",
                "FROM",
                "employees",
            )
        )

        # Query 2, as written by the user (before sentinel replacement).
        join_query = "\n".join(
            (
                "from (dbt source.salesforce.in_process)",
                "join (dbt ref.foo.bar) [id]",
                "filter salary > 100",
            )
        )

        # Query 2 after the `dbt ...` calls were swapped for jinja-like sentinels.
        join_query_with_sentinels = "\n".join(
            (
                "from (`{{ source('salesforce', 'in_process') }}`)",
                "join (`{{ ref('foo', 'bar') }}`) [id]",
                "filter salary > 100",
            )
        )

        join_sql = "\n".join(
            (
                "SELECT",
                "\"{{ source('salesforce', 'in_process') }}\".*,",
                "\"{{ ref('foo', 'bar') }}\".*,",
                "id",
                "FROM",
                "{{ source('salesforce', 'in_process') }}",
                "JOIN {{ ref('foo', 'bar') }} USING(id)",
                "WHERE",
                "salary > 100",
            )
        )

        # Both spellings of query 2 map to the same compiled SQL.
        canned = {
            "from employees": employees_sql,
            join_query: join_sql,
            join_query_with_sentinels: join_sql,
        }

        return canned[prql]


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# One name segment of a reference: word characters plus ".", "-" and "_".
# NOTE(review): because "." is allowed inside a segment, a reference with more than
# two dots is split greedily (the first name group absorbs the extra dots) — confirm
# that this is the intended behaviour.
word_regex = r"[\w\.\-_]+"
# Matches `dbt <type>.<name>.<name>`, optionally wrapped in backticks, capturing the
# three parts as groups 1-3 (e.g. `dbt source.salesforce.in_process`).
references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?"


def compile(prql: str, references: references_type):
    """
    Compile a PRQL query to SQL, leaving dbt references as jinja-like sentinels.

    Every `dbt <type>.<package>.<model>` call in `prql` is first rewritten to a
    backtick-quoted `{{ <type>('<package>', '<model>') }}` sentinel, and the result is
    handed to `prql_python.to_sql`. For example, `dbt source.salesforce.in_process`
    becomes `{{ source('salesforce', 'in_process') }}` before compilation.

    `references` is accepted but not used at the moment: the original intention was
    for this function to replace the sentinels with table names (via
    `replace_faux_jinja`), but we now just return the SQL and let dbt handle the
    sentinels; something expedient but not elegant.

    NOTE: while `prql_python` is the in-file mock, only the queries it knows about can
    be compiled; anything else raises `KeyError`.
    """

    # Disabled steps of the original design, kept for reference:
    #   references = list_references(prql)
    #   sql = replace_faux_jinja(sql, references)
    return prql_python.to_sql(_hack_sentinels_of_model(prql))


def list_references(prql):
    """
    Group every dbt reference (e.g. sources / refs) found in a PRQL block by its type.

    Returns a `defaultdict` mapping the reference type to a list of
    `(package, model)` tuples, in the order they appear in the query.

    Open questions:

    — What should prqlc return given `dbt source.foo.bar`, so dbt-prql can find the
      references?
    — Should it just fill in something that looks like jinja for expediency? (We
      don't support jinja though)

    >>> references = list_references("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
    >>> dict(references)
    {'source': [('salesforce', 'in_process')], 'ref': [('foo', 'bar')]}
    """

    grouped = defaultdict(list)
    for ref_type, package, model in _hack_references_of_prql_query(prql):
        grouped[ref_type].append((package, model))

    return grouped


def _hack_references_of_prql_query(prql) -> list[tuple[str, str, str]]:
    """
    Extract the `(type, package, model)` triple of every dbt reference in a prql query.

    This is a regex-based placeholder; the real implementation would live in prqlc.

    >>> _hack_references_of_prql_query("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
    [('source', 'salesforce', 'in_process'), ('ref', 'foo', 'bar')]
    """

    # All three groups of `references_regex` are mandatory, so `groups()` yields the
    # same 3-tuples of strings that `re.findall` would.
    return [found.groups() for found in re.finditer(references_regex, prql)]


def _hack_sentinels_of_model(prql: str) -> str:
    """
    Swap each `dbt <type>.<package>.<model>` call for a backtick-quoted, jinja-like
    sentinel.

    This is a regex-based placeholder; the substitution will be done by prqlc...

    >>> _hack_sentinels_of_model("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]")
    "from (`{{ source('salesforce', 'in_process') }}`) | join (`{{ ref('foo', 'bar') }}`) [id]"
    """
    # \1 = reference type, \2 = package, \3 = model.
    sentinel_template = r"`{{ \1('\2', '\3') }}`"
    pattern = re.compile(references_regex)
    return pattern.sub(sentinel_template, prql)


def replace_faux_jinja(sql: str, references: references_type):
    """
    Substitute every jinja-like sentinel in `sql` with its literal table name.

    `references` maps a reference type (e.g. `source`, `ref`) to a dict of
    `(package, table) -> literal`. Sentinels with no entry in `references` are left
    untouched.

    >>> print(replace_faux_jinja(
    ...     "SELECT * FROM {{ source('salesforce', 'in_process') }}",
    ...     references=dict(source={('salesforce', 'in_process'): 'salesforce_schema.in_process_tbl'})
    ... ))
    SELECT * FROM salesforce_schema.in_process_tbl

    """
    rendered = sql
    for ref_type, lookups in references.items():
        for (package, table), literal in lookups.items():
            # Doubled braces in the f-string produce the literal `{{ ... }}` sentinel.
            sentinel = f"{{{{ {ref_type}('{package}', '{table}') }}}}"
            rendered = rendered.replace(sentinel, literal)

    return rendered
21 changes: 20 additions & 1 deletion core/dbt/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _mangle_hooks(self, config):
config[key] = [hooks.get_hook_dict(h) for h in config[key]]

def _create_error_node(
self, name: str, path: str, original_file_path: str, raw_code: str, language: str = "sql"
self, name: str, path: str, original_file_path: str, raw_code: str, language: str
) -> UnparsedNode:
"""If we hit an error before we've actually parsed a node, provide some
level of useful information by attaching this to the exception.
Expand Down Expand Up @@ -192,6 +192,8 @@ def _create_parsetime_node(
name = block.name
if block.path.relative_path.endswith(".py"):
language = ModelLanguage.python
elif block.path.relative_path.endswith(".prql"):
language = ModelLanguage.prql
else:
# this is not ideal but we have a lot of tests to adjust if don't do it
language = ModelLanguage.sql
Expand Down Expand Up @@ -225,6 +227,7 @@ def _create_parsetime_node(
path=path,
original_file_path=block.path.original_file_path,
raw_code=block.contents,
language=language,
)
raise ParsingException(msg, node=node)

Expand Down Expand Up @@ -414,6 +417,22 @@ def parse_file(self, file_block: FileBlock) -> None:
self.parse_node(file_block)


# TODO: Currently `ModelParser` inherits from `SimpleSQLParser`, which inherits from
# SQLParser. So possibly `ModelParser` should instead be `SQLModelParser`, and in
# `ManifestLoader.load`, we should resolve which to use? (I think that's how it works;
# though then why all the inheritance and generics if every model is parsed with `ModelParser`?)
# class PRQLParser(
# ConfiguredParser[FileBlock, IntermediateNode, FinalNode], Generic[IntermediateNode, FinalNode]
# ):
Comment on lines +420 to +426
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds like this piece is the biggest current challenge, and the thing that would require the most refactoring to make pluggable.

Currently, SimpleSQLParser is used in all of these places:

  • ModelParser (even for Python models)
  • AnalysisParser
  • SingularTestParser
  • SqlBlockParser (for semantic layer queries)
  • SeedParser (!), even though render_with_context is a total no-op

We shouldn't be looking to create new node types for each language — these should still be models, analyses, tests, etc — but it's clear that we need some way to allow a language plugin to define its own parser, and then provide it to resources matching that language input.

It also looks like the real Jinja-SQL entrypoint is ConfiguredParser.render_with_context. We'd probably want to split that out, to live on a dedicated JinjaSQLParser instead!

Caveats:

  • There's some additional Jinja business in update_parsed_node_config_dict and update_parsed_node_name, to respect rules around hooks and database/schema/alias — those have more to do with model configuration than model language, and the fact that we've stretched Jinja into a "rules engine" for configuration that goes beyond its more-straightforwardly understood role as a SQL template.
  • I'd like to punt on snapshots, which are defined as Jinja blocks (and mostly config, anyway)
  • I don't know if we should even think/talk about macros, at this early point, but we know we'll want some way to define and reuse functions

# The full mro is:
# dbt.parser.models.ModelParser,
# dbt.parser.base.SimpleSQLParser,
# dbt.parser.base.SQLParser,
# dbt.parser.base.ConfiguredParser,
# dbt.parser.base.Parser,
# dbt.parser.base.BaseParser,


class SimpleSQLParser(SQLParser[FinalNode, FinalNode]):
def transform(self, node):
return node
61 changes: 61 additions & 0 deletions core/dbt/parser/language_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations


# try:
# import dbt_prql # type: ignore
# except ImportError:
# dbt_prql = None

# dict of ref_type (e.g. source, ref) -> (dict of (package, table) -> literal)
# references_type = dict[str, dict[tuple[str, str], str]]

# I can't get the above to work on CI; I had thought that it was fine with
# from __future__ import annotations, but it seems not. So, we'll just use Dict.
from typing import Dict, Tuple

references_type = Dict[str, Dict[Tuple[str, str], str]]


class LanguageProvider:
    """
    A LanguageProvider is a class that can parse a given language.

    Currently implemented only for PRQL, but we could extend this to other languages (in
    the medium term).

    Both methods here are placeholders that raise `NotImplementedError`.

    TODO: See notes in `ModelParser.render_update`; the current implementation has some
    mismatches. (For example, `PrqlProvider.compile` takes an additional `references`
    argument and `PrqlProvider` does not subclass this class.)
    """

    # Originally intended signature:
    # def compile(self, code: str) -> ParsedNode:
    def compile(self, code: str) -> str:
        """
        Compile a given block of code.

        The return annotation is `str`; returning a ParsedNode was the original
        intention (see the commented-out signature above).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("compile")

    def list_references(self, code: str) -> references_type:
        """
        List all references (e.g. sources / refs) in a given block.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("list_references")


class PrqlProvider:
    """
    Language provider for PRQL models.

    Each method delegates to the in-tree `_dbt_prql` module, which is imported lazily
    inside the method body.
    """

    def __init__(self) -> None:
        # TODO: Uncomment when dbt-prql is released
        # if not dbt_prql:
        #     raise ImportError(
        #         "dbt_prql is required and not found; try running `pip install dbt_prql`"
        #     )
        return None

    def compile(self, code: str, references: dict) -> str:
        """Compile PRQL `code` by delegating to `_dbt_prql.compile`."""
        from . import _dbt_prql as prql_backend

        return prql_backend.compile(code, references=references)

    def list_references(self, code: str) -> references_type:
        """List the dbt references in `code` by delegating to `_dbt_prql.list_references`."""
        from . import _dbt_prql as prql_backend

        return prql_backend.list_references(code)
21 changes: 21 additions & 0 deletions core/dbt/parser/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from dbt.dataclass_schema import ValidationError
from dbt.exceptions import ParsingException, validator_error_message, UndefinedMacroException

from .language_provider import PrqlProvider

dbt_function_key_words = set(["ref", "source", "config", "get"])
dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
Expand Down Expand Up @@ -233,6 +234,26 @@ def render_update(self, node: ParsedModelNode, config: ContextConfig) -> None:
raise ParsingException(msg, node=node) from exc
return

if node.language == ModelLanguage.prql:
provider = PrqlProvider()
try:
context = self._context_for(node, config)
references = provider.list_references(node.raw_code)
sql = provider.compile(node.raw_code, references)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After taking another look here — with plans to do some hacking on/around this soon! — I think we'd want the actual .compile() call (PRQL → SQL transpilation) to happen within compilation, rather than during parsing. Here's where we do that for Python.

I think the goal here, within the parser, should just be to provide the set of references, and then pick up again at compile time. Something to also think about: Certain modeling languages will want a database connection at compile time (e.g. Jinja-SQL, Cody's prototype work on Ibis); some won't need one (e.g. Snowpark/PySpark Python, PRQL), and so shouldn't require one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK great, that seems very reasonable.

Would you have it as an if in Compiler like the current python implementation, or inherit a class for each lang; i.e. CompilerPython / CompilerPrql?

# This is currently kinda a hack; I'm not happy about it.
#
# The original intention was to use the result of `.compile` and pass
# that to the Model. But currently the parsing seems extremely coupled
# to the `ModelParser` object. So possibly we should instead inherit
# from that, and there should be one of those for each language?
node.raw_code = sql
node.language = ModelLanguage.sql

except ValidationError as exc:
# we got a ValidationError - probably bad types in config()
msg = validator_error_message(exc)
raise ParsingException(msg, node=node) from exc

elif not flags.STATIC_PARSER:
# jinja rendering
super().render_update(node, config)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/read_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def read_files(project, files, parser_files, saved_files):
project,
files,
project.model_paths,
[".sql", ".py"],
[".sql", ".py", ".prql"],
ParseFileType.Model,
saved_files,
dbt_ignore_spec,
Expand Down
1 change: 1 addition & 0 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def get_hashable_md(data: Union[str, int, float, List, Dict]) -> Union[str, List
path=path,
original_file_path=target.original_file_path,
raw_code=raw_code,
language="sql",
)
raise ParsingException(msg, node=node) from exc

Expand Down
Loading
Loading