Skip to content

Commit

Permalink
Merge pull request #1549 from heisencoder/feature/add-project-dir-flag
Browse files Browse the repository at this point in the history
add --project-dir flag to allow specifying project directory
  • Loading branch information
drewbanin authored Jun 19, 2019
2 parents 9ad8512 + f5c3300 commit f834446
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 17 deletions.
10 changes: 10 additions & 0 deletions core/dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def run_from_args(parsed):
def _build_base_subparser():
base_subparser = argparse.ArgumentParser(add_help=False)

base_subparser.add_argument(
'--project-dir',
default=None,
type=str,
help="""
Which directory to look in for the dbt_project.yml file.
Default is the current working directory and its parents.
"""
)

base_subparser.add_argument(
'--profiles-dir',
default=PROFILES_DIR,
Expand Down
31 changes: 20 additions & 11 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,19 @@ def interpret_results(self, results):
return True


def get_nearest_project_dir():
def get_nearest_project_dir(args):
# If the user provides an explicit project directory, use that
# but don't look at parent directories.
if args.project_dir:
project_file = os.path.join(args.project_dir, "dbt_project.yml")
if os.path.exists(project_file):
return args.project_dir
else:
raise dbt.exceptions.RuntimeException(
"fatal: Invalid --project-dir flag. Not a dbt project. "
"Missing dbt_project.yml file"
)

root_path = os.path.abspath(os.sep)
cwd = os.getcwd()

Expand All @@ -102,24 +114,21 @@ def get_nearest_project_dir():
return cwd
cwd = os.path.dirname(cwd)

return None

raise dbt.exceptions.RuntimeException(
"fatal: Not a dbt project (or any of the parent directories). "
"Missing dbt_project.yml file"
)

def move_to_nearest_project_dir():
nearest_project_dir = get_nearest_project_dir()
if nearest_project_dir is None:
raise dbt.exceptions.RuntimeException(
"fatal: Not a dbt project (or any of the parent directories). "
"Missing dbt_project.yml file"
)

def move_to_nearest_project_dir(args):
nearest_project_dir = get_nearest_project_dir(args)
os.chdir(nearest_project_dir)


class RequiresProjectTask(BaseTask):
@classmethod
def from_args(cls, args):
move_to_nearest_project_dir()
move_to_nearest_project_dir(args)
return super(RequiresProjectTask, cls).from_args(args)


Expand Down
43 changes: 37 additions & 6 deletions test/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from dbt.adapters.redshift import RedshiftCredentials
from dbt.contracts.project import PackageConfig
from dbt.semver import VersionSpecifier
from dbt.task.run_operation import RunOperationTask


INITIAL_ROOT = os.getcwd()


@contextmanager
Expand Down Expand Up @@ -65,7 +69,7 @@ def temp_cd(path):

class Args(object):
def __init__(self, profiles_dir=None, threads=None, profile=None,
cli_vars=None, version_check=None):
cli_vars=None, version_check=None, project_dir=None):
self.profile = profile
if threads is not None:
self.threads = threads
Expand All @@ -75,11 +79,13 @@ def __init__(self, profiles_dir=None, threads=None, profile=None,
self.vars = cli_vars
if version_check is not None:
self.version_check = version_check
if project_dir is not None:
self.project_dir = project_dir


class BaseConfigTest(unittest.TestCase):
"""Subclass this, and before calling the superclass setUp, set
profiles_dir.
self.profiles_dir and self.project_dir.
"""
def setUp(self):
self.default_project_data = {
Expand Down Expand Up @@ -147,7 +153,7 @@ def setUp(self):
}
}
self.args = Args(profiles_dir=self.profiles_dir, cli_vars='{}',
version_check=True)
version_check=True, project_dir=self.project_dir)
self.env_override = {
'env_value_type': 'postgres',
'env_value_host': 'env-postgres-host',
Expand Down Expand Up @@ -176,7 +182,7 @@ def tearDown(self):
except EnvironmentError:
pass

def proejct_path(self, name):
def project_path(self, name):
return os.path.join(self.project_dir, name)

def profile_path(self, name):
Expand All @@ -185,11 +191,11 @@ def profile_path(self, name):
def write_project(self, project_data=None):
if project_data is None:
project_data = self.project_data
with open(self.proejct_path('dbt_project.yml'), 'w') as fp:
with open(self.project_path('dbt_project.yml'), 'w') as fp:
yaml.dump(project_data, fp)

def write_packages(self, package_data):
with open(self.proejct_path('packages.yml'), 'w') as fp:
with open(self.project_path('packages.yml'), 'w') as fp:
yaml.dump(package_data, fp)

def write_profile(self, profile_data=None):
Expand All @@ -202,6 +208,7 @@ def write_profile(self, profile_data=None):
class TestProfile(BaseConfigTest):
def setUp(self):
self.profiles_dir = '/invalid-path'
self.project_dir = '/invalid-project-path'
super(TestProfile, self).setUp()

def from_raw_profiles(self):
Expand Down Expand Up @@ -928,6 +935,30 @@ def test_with_invalid_package(self):
dbt.config.Project.from_project_root(self.project_dir, {})


class TestRunOperationTask(BaseFileTest):
def setUp(self):
super(TestRunOperationTask, self).setUp()
self.write_project(self.default_project_data)
self.write_profile(self.default_profile_data)

def tearDown(self):
super(TestRunOperationTask, self).tearDown()
# These tests will change the directory to the project path,
# so it's necessary to change it back at the end.
os.chdir(INITIAL_ROOT)

def test_run_operation_task(self):
self.assertEqual(os.getcwd(), INITIAL_ROOT)
self.assertNotEqual(INITIAL_ROOT, self.project_dir)
new_task = RunOperationTask.from_args(self.args)
self.assertEqual(os.getcwd(), self.project_dir)

def test_run_operation_task_with_bad_path(self):
self.args.project_dir = 'bad_path'
with self.assertRaises(dbt.exceptions.RuntimeException):
new_task = RunOperationTask.from_args(self.args)


class TestVariableProjectFile(BaseFileTest):
def setUp(self):
super(TestVariableProjectFile, self).setUp()
Expand Down

0 comments on commit f834446

Please sign in to comment.