-
Notifications
You must be signed in to change notification settings - Fork 57
/
commands.py
347 lines (314 loc) · 10.1 KB
/
commands.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os.path
import typer
from click import ClickException, Context, Parameter # type: ignore
from click.core import ParameterSource # type: ignore
from click.types import StringParamType
from snowflake.cli._plugins.connection.util import (
strip_and_check_if_exists,
strip_if_value_present,
)
from snowflake.cli._plugins.object.manager import ObjectManager
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.commands.flags import (
PLAIN_PASSWORD_MSG,
)
from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
from snowflake.cli.api.config import (
ConnectionConfig,
add_connection,
connection_exists,
get_all_connections,
get_connection_dict,
get_default_connection_name,
set_config_value,
)
from snowflake.cli.api.console import cli_console
from snowflake.cli.api.constants import ObjectType
from snowflake.cli.api.output.types import (
CollectionResult,
CommandResult,
MessageResult,
ObjectResult,
)
from snowflake.connector import ProgrammingError
from snowflake.connector.config_manager import CONFIG_MANAGER
# Module-level logger for this command group.
log = logging.getLogger(__name__)

# Typer sub-application that every `@app.command` in this module registers on.
app = SnowTyperFactory(
    name="connection",
    help="Manages connections to Snowflake.",
)
class EmptyInput:
    """Sentinel used as the default of optional prompted options.

    Its repr is the text ``optional``. When used as a prompt default, this is
    presumably what the prompt displays as the default hint — verify against
    click's prompt rendering.
    """

    _LABEL = "optional"

    def __repr__(self):
        return self._LABEL
class OptionalPrompt(StringParamType):
    """String parameter type that turns the ``EmptyInput`` sentinel into ``None``.

    Any other value is returned unchanged. Note that ``StringParamType.convert``
    is deliberately not called.
    """

    def convert(self, value, param, ctx):
        # The sentinel means "nothing was provided" for an optional prompt.
        if isinstance(value, EmptyInput):
            return None
        return value
def _mask_password(connection_params: dict):
if "password" in connection_params:
connection_params["password"] = "****"
return connection_params
@app.command(name="list")
def list_connections(**options) -> CommandResult:
    """
    Lists configured connections.
    """
    all_connections = get_all_connections()
    default_name = get_default_connection_name()

    def _rows(entries):
        # One row per configured connection, with the password masked via
        # _mask_password and a flag marking the default connection.
        for name, config in entries:
            params = config.to_dict_of_known_non_empty_values()
            yield {
                "connection_name": name,
                "parameters": _mask_password(params),
                "is_default": name == default_name,
            }

    # `.items()` is taken eagerly here, matching when a generator expression
    # evaluates its first iterable; the rows themselves are produced lazily.
    return CollectionResult(_rows(all_connections.items()))
def require_integer(field_name: str):
    """Build an option callback that only accepts non-negative integer text.

    Args:
        field_name: Name of the option, used in the error message.

    Returns:
        A callback that returns ``None`` when no value was given, and the
        whitespace-stripped string when it consists only of ASCII digits.

    Raises (from the returned callback):
        ClickException: if the value is not made of ASCII digits.
    """

    def callback(value: str):
        if value is None:
            return None
        stripped = value.strip()
        # str.isdigit() alone also accepts characters such as "²" that int()
        # cannot parse, so additionally require the text to be pure ASCII.
        if stripped.isascii() and stripped.isdigit():
            return stripped
        raise ClickException(f"Value of {field_name} must be integer")

    return callback
def _password_callback(ctx: Context, param: Parameter, value: str):
if value and ctx.get_parameter_source(param.name) == ParameterSource.COMMANDLINE: # type: ignore
cli_console.warning(PLAIN_PASSWORD_MSG)
return value
@app.command()
def add(
    connection_name: str = typer.Option(
        None,
        "--connection-name",
        "-n",
        prompt="Name for this connection",
        help="Name of the new connection.",
        show_default=False,
        callback=strip_if_value_present,
    ),
    account: str = typer.Option(
        None,
        "--account",
        "-a",
        "--accountname",
        prompt="Snowflake account name",
        help="Account name to use when authenticating with Snowflake.",
        show_default=False,
        callback=strip_if_value_present,
    ),
    user: str = typer.Option(
        None,
        "--user",
        "-u",
        "--username",
        prompt="Snowflake username",
        show_default=False,
        help="Username to connect to Snowflake.",
        callback=strip_if_value_present,
    ),
    password: str = typer.Option(
        EmptyInput(),
        "--password",
        "-p",
        click_type=OptionalPrompt(),
        # Not stripped: whitespace may be part of a password.
        callback=_password_callback,
        prompt="Snowflake password",
        help="Snowflake password.",
        hide_input=True,
    ),
    role: str = typer.Option(
        EmptyInput(),
        "--role",
        "-r",
        click_type=OptionalPrompt(),
        prompt="Role for the connection",
        help="Role to use on Snowflake.",
        callback=strip_if_value_present,
    ),
    warehouse: str = typer.Option(
        EmptyInput(),
        "--warehouse",
        "-w",
        click_type=OptionalPrompt(),
        prompt="Warehouse for the connection",
        help="Warehouse to use on Snowflake.",
        callback=strip_if_value_present,
    ),
    database: str = typer.Option(
        EmptyInput(),
        "--database",
        "-d",
        click_type=OptionalPrompt(),
        prompt="Database for the connection",
        help="Database to use on Snowflake.",
        callback=strip_if_value_present,
    ),
    schema: str = typer.Option(
        EmptyInput(),
        "--schema",
        "-s",
        click_type=OptionalPrompt(),
        prompt="Schema for the connection",
        help="Schema to use on Snowflake.",
        callback=strip_if_value_present,
    ),
    host: str = typer.Option(
        EmptyInput(),
        "--host",
        "-h",
        click_type=OptionalPrompt(),
        prompt="Connection host",
        help="Host name the connection attempts to connect to Snowflake.",
        callback=strip_if_value_present,
    ),
    port: int = typer.Option(
        EmptyInput(),
        "--port",
        "-P",
        click_type=OptionalPrompt(),
        prompt="Connection port",
        help="Port to communicate with on the host.",
        # require_integer validates the text and returns it as a stripped string.
        callback=require_integer(field_name="port"),
    ),
    region: str = typer.Option(
        EmptyInput(),
        "--region",
        "-R",
        click_type=OptionalPrompt(),
        prompt="Snowflake region",
        help="Region name if not the default Snowflake deployment.",
        callback=strip_if_value_present,
    ),
    authenticator: str = typer.Option(
        EmptyInput(),
        "--authenticator",
        "-A",
        click_type=OptionalPrompt(),
        prompt="Authentication method",
        help="Chosen authenticator, if other than password-based",
        # Strip surrounding whitespace like every other free-text option here.
        callback=strip_if_value_present,
    ),
    private_key_path: str = typer.Option(
        EmptyInput(),
        "--private-key",
        "-k",
        click_type=OptionalPrompt(),
        prompt="Path to private key file",
        help="Path to file containing private key",
        callback=strip_and_check_if_exists,
    ),
    token_file_path: str = typer.Option(
        EmptyInput(),
        "--token-file-path",
        "-t",
        click_type=OptionalPrompt(),
        prompt="Path to token file",
        help="Path to file with an OAuth token that should be used when connecting to Snowflake",
        callback=strip_and_check_if_exists,
    ),
    set_as_default: bool = typer.Option(
        False,
        "--default",
        is_flag=True,
        help="If provided the connection will be configured as default connection.",
    ),
    **options,
) -> CommandResult:
    """Adds a connection to configuration file."""
    # Refuse to overwrite an existing entry with the same name.
    if connection_exists(connection_name):
        raise ClickException(f"Connection {connection_name} already exists")

    # Optional options left at their EmptyInput default are mapped to None by
    # OptionalPrompt.convert before reaching this point.
    add_connection(
        connection_name,
        ConnectionConfig(
            account=account,
            user=user,
            password=password,
            host=host,
            region=region,
            port=port,
            database=database,
            schema=schema,
            warehouse=warehouse,
            role=role,
            authenticator=authenticator,
            private_key_path=private_key_path,
            token_file_path=token_file_path,
        ),
    )

    if set_as_default:
        set_config_value(
            section=None, key="default_connection_name", value=connection_name
        )

    return MessageResult(
        f"Wrote new connection {connection_name} to {CONFIG_MANAGER.file_path}"
    )
def _quote_identifier(name: str) -> str:
    """Return *name* as a double-quoted Snowflake identifier.

    Embedded double quotes are escaped by doubling them. This keeps the result
    a single quoted identifier even when *name* itself contains ``"``.
    """
    escaped = name.replace('"', '""')
    return f'"{escaped}"'


@app.command(requires_connection=True)
def test(
    **options,
) -> CommandResult:
    """
    Tests the connection to Snowflake.
    """
    # Test connection: obtaining the connection from the CLI context is the
    # first check (the command is registered with requires_connection=True).
    cli_context = get_cli_context()
    conn = cli_context.connection

    # Test session attributes by issuing USE for each configured object.
    om = ObjectManager()
    try:
        # "use database" operation changes schema to default "public",
        # so to test schema set up by user we need to copy it here:
        schema = conn.schema

        if conn.role:
            om.use(object_type=ObjectType.ROLE, name=_quote_identifier(conn.role))
        if conn.database:
            om.use(
                object_type=ObjectType.DATABASE,
                name=_quote_identifier(conn.database),
            )
        if schema:
            om.use(object_type=ObjectType.SCHEMA, name=_quote_identifier(schema))
        if conn.warehouse:
            om.use(
                object_type=ObjectType.WAREHOUSE,
                name=_quote_identifier(conn.warehouse),
            )
    except ProgrammingError as err:
        # Surface Snowflake errors as a CLI error, keeping the original cause.
        raise ClickException(str(err)) from err

    conn_ctx = cli_context.connection_context
    result = {
        "Connection name": conn_ctx.connection_name,
        "Status": "OK",
        "Host": conn.host,
        "Account": conn.account,
        "User": conn.user,
        "Role": conn.role or "not set",
        "Database": conn.database or "not set",
        "Warehouse": conn.warehouse or "not set",
    }

    # With diagnostics enabled, report where the diagnostic report file is
    # expected to be written.
    if conn_ctx.enable_diag:
        result["Diag Report Location"] = os.path.join(
            conn_ctx.diag_log_path, "SnowflakeConnectionTestReport.txt"
        )

    return ObjectResult(result)
@app.command(requires_connection=False)
def set_default(
    name: str = typer.Argument(
        help="Name of the connection, as defined in your `config.toml`",
        show_default=False,
    ),
    **options,
):
    """Changes default connection to provided value."""
    # Look the connection up first; the dict itself is unused. This presumably
    # raises for an unknown name before the config is written — verify in
    # get_connection_dict.
    _ = get_connection_dict(connection_name=name)
    set_config_value(section=None, key="default_connection_name", value=name)
    confirmation = f"Default connection set to: {name}"
    return MessageResult(confirmation)