-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
init.py
312 lines (278 loc) · 13.5 KB
/
init.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import copy
import os
from pathlib import Path
import re
import shutil
import sys
from typing import Optional
import yaml
import click
import dbt.config
import dbt.clients.system
from dbt.flags import get_flags
from dbt.version import _get_adapter_plugin_names
from dbt.adapters.factory import load_plugin, get_include_paths
from dbt.contracts.util import Identifier as ProjectName
from dbt.events.functions import fire_event
from dbt.events.types import (
StarterProjectPath,
ConfigFolderDirectory,
NoSampleProfileFound,
ProfileWrittenWithSample,
ProfileWrittenWithTargetTemplateYAML,
ProfileWrittenWithProjectTemplateYAML,
SettingUpProfile,
InvalidProfileTemplateYAML,
ProjectNameAlreadyExists,
ProjectCreated,
)
from dbt.include.starter_project import PACKAGE_PATH as starter_project_directory
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.task.base import BaseTask, move_to_nearest_project_dir
DOCS_URL = "https://docs.getdbt.com/docs/configure-your-profile"
SLACK_URL = "https://community.getdbt.com/"
# This file is not needed for the starter project but exists for finding the resource path
IGNORE_FILES = ["__init__.py", "__pycache__"]
# https://click.palletsprojects.com/en/8.0.x/api/#types
# click v7.0 has UNPROCESSED, STRING, INT, FLOAT, BOOL, and UUID available.
click_type_mapping = {
"string": click.STRING,
"int": click.INT,
"float": click.FLOAT,
"bool": click.BOOL,
None: None,
}
class InitTask(BaseTask):
def copy_starter_repo(self, project_name):
fire_event(StarterProjectPath(dir=starter_project_directory))
shutil.copytree(
starter_project_directory, project_name, ignore=shutil.ignore_patterns(*IGNORE_FILES)
)
def create_profiles_dir(self, profiles_dir: str) -> bool:
"""Create the user's profiles directory if it doesn't already exist."""
profiles_path = Path(profiles_dir)
if not profiles_path.exists():
fire_event(ConfigFolderDirectory(dir=profiles_dir))
dbt.clients.system.make_directory(profiles_dir)
return True
return False
def create_profile_from_sample(self, adapter: str, profile_name: str):
"""Create a profile entry using the adapter's sample_profiles.yml
Renames the profile in sample_profiles.yml to match that of the project."""
# Line below raises an exception if the specified adapter is not found
load_plugin(adapter)
adapter_path = get_include_paths(adapter)[0]
sample_profiles_path = adapter_path / "sample_profiles.yml"
if not sample_profiles_path.exists():
fire_event(NoSampleProfileFound(adapter=adapter))
else:
with open(sample_profiles_path, "r") as f:
sample_profile = f.read()
sample_profile_name = list(yaml.safe_load(sample_profile).keys())[0]
# Use a regex to replace the name of the sample_profile with
# that of the project without losing any comments from the sample
sample_profile = re.sub(f"^{sample_profile_name}:", f"{profile_name}:", sample_profile)
profiles_filepath = Path(get_flags().PROFILES_DIR) / Path("profiles.yml")
if profiles_filepath.exists():
with open(profiles_filepath, "a") as f:
f.write("\n" + sample_profile)
else:
with open(profiles_filepath, "w") as f:
f.write(sample_profile)
fire_event(
ProfileWrittenWithSample(name=profile_name, path=str(profiles_filepath))
)
def generate_target_from_input(self, profile_template: dict, target: dict = {}) -> dict:
"""Generate a target configuration from profile_template and user input."""
profile_template_local = copy.deepcopy(profile_template)
for key, value in profile_template_local.items():
if key.startswith("_choose"):
choice_type = key[8:].replace("_", " ")
option_list = list(value.keys())
prompt_msg = (
"\n".join([f"[{n+1}] {v}" for n, v in enumerate(option_list)])
+ f"\nDesired {choice_type} option (enter a number)"
)
numeric_choice = click.prompt(prompt_msg, type=click.INT)
choice = option_list[numeric_choice - 1]
# Complete the chosen option's values in a recursive call
target = self.generate_target_from_input(
profile_template_local[key][choice], target
)
else:
if key.startswith("_fixed"):
# _fixed prefixed keys are not presented to the user
target[key[7:]] = value
else:
hide_input = value.get("hide_input", False)
default = value.get("default", None)
hint = value.get("hint", None)
type = click_type_mapping[value.get("type", None)]
text = key + (f" ({hint})" if hint else "")
target[key] = click.prompt(
text, default=default, hide_input=hide_input, type=type
)
return target
def get_profile_name_from_current_project(self) -> str:
"""Reads dbt_project.yml in the current directory to retrieve the
profile name.
"""
with open("dbt_project.yml") as f:
dbt_project = yaml.safe_load(f)
return dbt_project["profile"]
def write_profile(self, profile: dict, profile_name: str):
"""Given a profile, write it to the current project's profiles.yml.
This will overwrite any profile with a matching name."""
# Create the profile directory if it doesn't exist
profiles_filepath = Path(get_flags().PROFILES_DIR) / Path("profiles.yml")
if profiles_filepath.exists():
with open(profiles_filepath, "r+") as f:
profiles = yaml.safe_load(f) or {}
profiles[profile_name] = profile
f.seek(0)
yaml.dump(profiles, f)
f.truncate()
else:
profiles = {profile_name: profile}
with open(profiles_filepath, "w") as f:
yaml.dump(profiles, f)
def create_profile_from_profile_template(self, profile_template: dict, profile_name: str):
"""Create and write a profile using the supplied profile_template."""
initial_target = profile_template.get("fixed", {})
prompts = profile_template.get("prompts", {})
target = self.generate_target_from_input(prompts, initial_target)
target_name = target.pop("target", "dev")
profile = {"outputs": {target_name: target}, "target": target_name}
self.write_profile(profile, profile_name)
def create_profile_from_target(self, adapter: str, profile_name: str):
"""Create a profile without defaults using target's profile_template.yml if available, or
sample_profiles.yml as a fallback."""
# Line below raises an exception if the specified adapter is not found
load_plugin(adapter)
adapter_path = get_include_paths(adapter)[0]
profile_template_path = adapter_path / "profile_template.yml"
if profile_template_path.exists():
with open(profile_template_path) as f:
profile_template = yaml.safe_load(f)
self.create_profile_from_profile_template(profile_template, profile_name)
profiles_filepath = Path(get_flags().PROFILES_DIR) / Path("profiles.yml")
fire_event(
ProfileWrittenWithTargetTemplateYAML(
name=profile_name, path=str(profiles_filepath)
)
)
else:
# For adapters without a profile_template.yml defined, fallback on
# sample_profiles.yml
self.create_profile_from_sample(adapter, profile_name)
def check_if_can_write_profile(self, profile_name: Optional[str] = None) -> bool:
"""Using either a provided profile name or that specified in dbt_project.yml,
check if the profile already exists in profiles.yml, and if so ask the
user whether to proceed and overwrite it."""
profiles_file = Path(get_flags().PROFILES_DIR) / Path("profiles.yml")
if not profiles_file.exists():
return True
profile_name = profile_name or self.get_profile_name_from_current_project()
with open(profiles_file, "r") as f:
profiles = yaml.safe_load(f) or {}
if profile_name in profiles.keys():
response = click.confirm(
f"The profile {profile_name} already exists in "
f"{profiles_file}. Continue and overwrite it?"
)
return response
else:
return True
def create_profile_using_project_profile_template(self, profile_name):
"""Create a profile using the project's profile_template.yml"""
with open("profile_template.yml") as f:
profile_template = yaml.safe_load(f)
self.create_profile_from_profile_template(profile_template, profile_name)
profiles_filepath = Path(get_flags().PROFILES_DIR) / Path("profiles.yml")
fire_event(
ProfileWrittenWithProjectTemplateYAML(name=profile_name, path=str(profiles_filepath))
)
def ask_for_adapter_choice(self) -> str:
"""Ask the user which adapter (database) they'd like to use."""
available_adapters = list(_get_adapter_plugin_names())
prompt_msg = (
"Which database would you like to use?\n"
+ "\n".join([f"[{n+1}] {v}" for n, v in enumerate(available_adapters)])
+ "\n\n(Don't see the one you want? https://docs.getdbt.com/docs/available-adapters)"
+ "\n\nEnter a number"
)
numeric_choice = click.prompt(prompt_msg, type=click.INT)
return available_adapters[numeric_choice - 1]
def get_valid_project_name(self) -> str:
"""Returns a valid project name, either from CLI arg or user prompt."""
name = self.args.project_name
internal_package_names = {GLOBAL_PROJECT_NAME}
available_adapters = list(_get_adapter_plugin_names())
for adapter_name in available_adapters:
internal_package_names.update(f"dbt_{adapter_name}")
while not ProjectName.is_valid(name) or name in internal_package_names:
if name:
click.echo(name + " is not a valid project name.")
name = click.prompt("Enter a name for your project (letters, digits, underscore)")
return name
def run(self):
"""Entry point for the init task."""
profiles_dir = get_flags().PROFILES_DIR
self.create_profiles_dir(profiles_dir)
try:
move_to_nearest_project_dir(self.args.project_dir)
in_project = True
except dbt.exceptions.DbtRuntimeError:
in_project = False
if in_project:
# When dbt init is run inside an existing project,
# just setup the user's profile.
fire_event(SettingUpProfile())
profile_name = self.get_profile_name_from_current_project()
if not self.check_if_can_write_profile(profile_name=profile_name):
return
# If a profile_template.yml exists in the project root, that effectively
# overrides the profile_template.yml for the given target.
profile_template_path = Path("profile_template.yml")
if profile_template_path.exists():
try:
# This relies on a valid profile_template.yml from the user,
# so use a try: except to fall back to the default on failure
self.create_profile_using_project_profile_template(profile_name)
return
except Exception:
fire_event(InvalidProfileTemplateYAML())
adapter = self.ask_for_adapter_choice()
self.create_profile_from_target(adapter, profile_name=profile_name)
return
# When dbt init is run outside of an existing project,
# create a new project and set up the user's profile.
available_adapters = list(_get_adapter_plugin_names())
if not len(available_adapters):
print("No adapters available. Go to https://docs.getdbt.com/docs/available-adapters")
sys.exit(1)
project_name = self.get_valid_project_name()
project_path = Path(project_name)
if project_path.exists():
fire_event(ProjectNameAlreadyExists(name=project_name))
return
self.copy_starter_repo(project_name)
os.chdir(project_name)
with open("dbt_project.yml", "r+") as f:
content = f"{f.read()}".format(project_name=project_name, profile_name=project_name)
f.seek(0)
f.write(content)
f.truncate()
# Ask for adapter only if skip_profile_setup flag is not provided.
if not self.args.skip_profile_setup:
if not self.check_if_can_write_profile(profile_name=project_name):
return
adapter = self.ask_for_adapter_choice()
self.create_profile_from_target(adapter, profile_name=project_name)
fire_event(
ProjectCreated(
project_name=project_name,
docs_url=DOCS_URL,
slack_url=SLACK_URL,
)
)