From 7c5919d2bc7b9da262694e04816b335c0541c3df Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 2 Oct 2022 19:08:10 -0700 Subject: [PATCH] Add test to test_graph.py --- core/dbt/parser/_dbt_prql.py | 46 +++++++++++++++++++++++++++++------- test/unit/test_graph.py | 24 ++++++++++++++++++- test/unit/test_parser.py | 8 +++---- 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/core/dbt/parser/_dbt_prql.py b/core/dbt/parser/_dbt_prql.py index 52c72cac390..6cd762f4778 100644 --- a/core/dbt/parser/_dbt_prql.py +++ b/core/dbt/parser/_dbt_prql.py @@ -13,16 +13,37 @@ if typing.TYPE_CHECKING: from dbt.parser.language_provider import references_type -# import prql_python - -# Always return the same SQL, mocking the prqlc output for a single case which we -# currently use in tests, so we can test this without configuring dependencies. (Obv -# fix as we expand the tests, way before we merge.) +# import prql_python +# This mocks the prqlc output for two cases which we currently use in tests, so we can +# test this without configuring dependencies. (Obv fix as we expand the tests, way +# before we merge.) class prql_python: # type: ignore @staticmethod - def to_sql(prql): - compiled_sql = """ + def to_sql(prql) -> str: + + query_1 = "from employees" + + query_1_compiled = """ +SELECT + employees.* +FROM + employees + """.strip() + + query_2 = """ +from (dbt source.salesforce.in_process) +join (dbt ref.foo.bar) [id] +filter salary > 100 + """.strip() + + query_2_refs_replaced = """ +from (`{{ source('salesforce', 'in_process') }}`) +join (`{{ ref('foo', 'bar') }}`) [id] +filter salary > 100 + """.strip() + + query_2_compiled = """ SELECT "{{ source('salesforce', 'in_process') }}".*, "{{ ref('foo', 'bar') }}".*, @@ -33,7 +54,16 @@ def to_sql(prql): WHERE salary > 100 """.strip() - return compiled_sql + + lookup = dict( + { + query_1: query_1_compiled, + query_2: query_2_compiled, + query_2_refs_replaced: query_2_compiled, + } + ) + + return lookup[prql] logger = logging.getLogger(__name__) diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 6efb14ceff1..959b1ceaa4e 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -60,6 +60,9 @@ def setUp(self): # Create file filesystem searcher self.filesystem_search = patch('dbt.parser.read_files.filesystem_search') def mock_filesystem_search(project, relative_dirs, extension, ignore_spec): + # Adding in `and "prql" not in extension` will cause a bunch of tests to + # fail; need to understand more on how these are constructed to debug. + # Possibly `sql not in extension` is a way of having it only run once. if 'sql' not in extension: return [] if 'models' not in relative_dirs: @@ -144,7 +147,7 @@ def use_models(self, models): path = FilePath( searched_path='models', project_root=os.path.normcase(os.getcwd()), - relative_path='{}.sql'.format(k), + relative_path=f'{k}.{lang}', modification_time=0.0, ) # FileHash can't be empty or 'search_key' will be None @@ -328,3 +331,22 @@ def test__partial_parse(self): manifest.metadata.dbt_version = '99999.99.99' is_partial_parsable, _ = loader.is_partial_parsable(manifest) self.assertFalse(is_partial_parsable) + + def test_models_prql(self): + self.use_models({ + 'model_prql':( 'from employees', 'prql'), + }) + + config = self.get_config() + manifest = self.load_manifest(config) + + compiler = self.get_compiler(config) + linker = compiler.compile(manifest) + + self.assertEqual( + list(linker.nodes()), + ['model.test_models_compile.model_prql']) + + self.assertEqual( + list(linker.edges()), + []) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index aceb74ee9cb..b159acca86d 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -716,10 +716,10 @@ def test_parse_error(self): def test_parse_prql_file(self): prql_code = """ - from (dbt source.salesforce.in_process) - join (dbt ref.foo.bar) [id] - filter salary > 100 - """ +from (dbt source.salesforce.in_process) +join (dbt ref.foo.bar) [id] +filter salary > 100 + """.strip() block = self.file_block_for(prql_code, 'nested/prql_model.prql') self.parser.manifest.files[block.file.file_id] = block.file self.parser.parse_file(block)