forked from catherinedevlin/ipython-sql
-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathmagic.py
787 lines (686 loc) · 25.5 KB
/
magic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
import json
import re
from pathlib import Path
import sqlparse
try:
from ipywidgets import interact
except ModuleNotFoundError:
interact = None
from ploomber_core.exceptions import modify_exceptions
from IPython.core.magic import (
Magics,
cell_magic,
line_magic,
magics_class,
needs_local_scope,
no_var_expand,
)
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
from sqlalchemy.exc import (
OperationalError,
ProgrammingError,
DatabaseError,
StatementError,
)
from traitlets.config.configurable import Configurable
from traitlets import Bool, Int, TraitError, Unicode, Dict, observe, validate
from sql.traits import Parameters
import warnings
import shlex
import sql.connection
import sql.parse
from sql.run.run import run_statements
from sql.parse import _option_strings_from_parser
from sql import display, exceptions
from sql.store import store
from sql.command import SQLCommand
from sql.magic_plot import SqlPlotMagic
from sql.magic_cmd import SqlCmdMagic
from sql._patch import patch_ipython_usage_error
from sql import util
from sql.error_handler import handle_exception
from sql._current import _set_sql_magic
from ploomber_core.dependencies import check_installed
try:
from pandas.core.frame import DataFrame, Series
except ModuleNotFoundError:
DataFrame = None
Series = None
# Not referenced in the visible part of this file; presumably the ipywidgets
# widget types supported by --interact (TODO confirm, including why the list
# ends with an empty string).
SUPPORT_INTERACTIVE_WIDGETS = ["Checkbox", "Text", "IntSlider", ""]
# Fragments interpolated by SqlMagic._execute into the warnings it shows when
# stored snippets are detected but CTE generation is skipped. The first one is
# used when the query type is not SELECT; the backslash continues the literal,
# so the two source lines form a single string.
IF_NOT_SELECT_MESSAGE = "The query is not a SELECT type query and as \
snippets only work with SELECT queries,"
# Used when the query itself starts with "with " and its type is SELECT.
IF_SELECT_MESSAGE = "JupySQL does not support snippet expansion within CTEs yet,"
@magics_class
class RenderMagic(Magics):
    """
    %sqlrender magic which prints composed queries

    The magic looks up a saved query by name in the module-level ``store`` and
    returns it as a string. It always emits a ``FutureWarning`` pointing users
    to ``%sqlcmd snippets``.
    """

    @line_magic
    @magic_arguments()
    # TODO: only accept one arg
    @argument("line", default="", nargs="*", type=str)
    @argument(
        "-w",
        "--with",
        type=str,
        help="Use a saved query",
        action="append",
        dest="with_",
    )
    def sqlrender(self, line):
        # NOTE: no docstring on purpose; magic_arguments folds the function
        # docstring into the generated help text, and the original had none.
        #
        # line: raw text typed after ``%sqlrender``; its first positional token
        #       is the name of the saved query. ``--with`` is parsed but not
        #       read here.
        # returns: ``str(store[name])``
        # raises: exceptions.UsageError when no name was given.
        args = parse_argstring(self.sqlrender, line)

        # With nargs="*" and default="", an empty line yields an empty value,
        # so indexing [0] would fail with a bare IndexError. Fail with an
        # actionable message instead.
        if not args.line:
            raise exceptions.UsageError(
                "Missing argument: %sqlrender <name_of_saved_query>"
            )

        snippet_name = args.line[0]
        warnings.warn(
            "\n'%sqlrender' will be deprecated soon, "
            f"please use '%sqlcmd snippets {snippet_name}' instead. "
            "\n\nFor documentation, follow this link : "
            "https://jupysql.ploomber.io/en/latest/api/magic-snippets.html#id1",
            FutureWarning,
        )
        return str(store[snippet_name])
@magics_class
class SqlMagic(Magics, Configurable):
    """Runs SQL statement on a database, specified by SQLAlchemy connect string.

    Provides the %%sql magic."""

    # Every attribute below is a traitlets trait declared with config=True;
    # __init__ appends the instance to shell.configurables so they can be
    # changed via %config.

    autocommit = Bool(default_value=True, config=True, help="Set autocommit mode")
    autolimit = Int(
        default_value=0,
        config=True,
        allow_none=True,
        help="Automatically limit the size of the returned result sets",
    )
    # autopandas and autopolars are kept mutually exclusive by the
    # _mutex_autopandas_autopolars observer.
    autopandas = Bool(
        default_value=False,
        config=True,
        help="Return Pandas DataFrames instead of regular result sets",
    )
    autopolars = Bool(
        default_value=False,
        config=True,
        help="Return Polars DataFrames instead of regular result sets",
    )
    # When set, _execute writes the result into shell.user_ns instead of
    # returning it.
    column_local_vars = Bool(
        default_value=False,
        config=True,
        help="Return data into local variables from column names",
    )
    # Forwarded to ConnectionManager.set as ``displaycon`` in _execute.
    displaycon = Bool(
        default_value=True, config=True, help="Show connection string after execution"
    )
    # Validated by _valid_displaylimit: None becomes 0, negatives are rejected.
    displaylimit = Int(
        default_value=10,
        config=True,
        allow_none=True,
        help=(
            "Automatically limit the number of rows "
            "displayed (full result set is still stored)"
        ),
    )
    # Validated by _valid_dsn_filename, which expands "~" in assigned values.
    dsn_filename = Unicode(
        default_value=str(Path("~/.jupysql/connections.ini").expanduser()),
        config=True,
        help="Path to DSN file. "
        "When the first argument is of the form [section], "
        "a sqlalchemy connection string is formed from the "
        "matching section in the DSN file.",
    )
    # In this file only its truthiness is checked (it gates the "Returning data
    # to local variables" message in _execute).
    feedback = Int(
        default_value=1,
        config=True,
        help="Verbosity level. 0=minimal, 1=normal, 2=all",
    )
    lazy_execution = Bool(
        default_value=False,
        config=True,
        help="Whether to evaluate using ResultSet which will "
        "cause the plan to execute or just return a lazily "
        "executed plan allowing validating schemas, "
        "without expensive compute."
        "Currently only supported for Spark Connection.",
    )
    # _execute compares this against "disabled" and "enabled" to decide which
    # parameters mapping is handed to run_statements; the default is "warn".
    named_parameters = Parameters(
        default_value="warn",
        config=True,
        help=(
            "Allow named parameters in queries "
            "(i.e., 'SELECT * FROM foo WHERE bar = :bar')"
        ),
    )
    polars_dataframe_kwargs = Dict(
        default_value={},
        config=True,
        help=(
            "Polars DataFrame constructor keyword arguments"
            "(e.g. infer_schema_length, nan_to_null, schema_overrides, etc)"
        ),
    )
    # Passed to handle_exception as its third argument in _execute.
    short_errors = Bool(
        default_value=True,
        config=True,
        help="Don't display the full traceback on SQL Programming Error",
    )
    style = Unicode(
        default_value="DEFAULT",
        config=True,
        help=(
            "Set the table printing style to any of prettytable's "
            "defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, "
            "RANDOM, SINGLE_BORDER, DOUBLE_BORDER, MARKDOWN )"
        ),
    )
def __init__(self, shell):
    """Initialise both base classes and expose this instance to ``%config``.

    ``shell`` supplies the ``config`` handed to ``Configurable`` and is passed
    on to ``Magics`` as the owning shell.
    """
    # Keep a handle on the module-level snippet store imported from sql.store
    self._store = store
    # Both bases are initialised explicitly, each with its own keyword argument
    Configurable.__init__(self, config=shell.config)
    Magics.__init__(self, shell=shell)
    # Add ourself to the list of module configurable via %config
    self.shell.configurables.append(self)
@validate("dsn_filename")
def _valid_dsn_filename(self, proposal):
    """Return the proposed ``dsn_filename`` as a string with ``~`` expanded."""
    proposed = proposal["value"]
    return str(Path(proposed).expanduser())
# To verify displaylimit is valid positive integer
# If:
# None -> We treat it as 0 (no limit)
# Positive Integer -> Pass
# Negative Integer -> raise Error
@validate("displaylimit")
def _valid_displaylimit(self, proposal):
    """Validate a proposed ``displaylimit`` value and return the value to store.

    ``None`` is reported to the user and mapped to 0 (no limit). Any other
    value is coerced with ``int()`` and returned when it is not negative.

    Raises:
        TraitError: if the value is negative or cannot be converted to an
            integer.
    """
    proposed = proposal["value"]
    if proposed is None:
        display.message("displaylimit: Value None will be treated as 0 (no limit)")
        return 0

    try:
        value = int(proposed)
    except (TypeError, ValueError) as err:
        # Report the raw proposal: ``value`` is never bound when int() fails,
        # so formatting it here would raise UnboundLocalError and hide the
        # intended TraitError.
        raise TraitError(
            "{}: displaylimit is not an integer".format(proposed)
        ) from err

    if value < 0:
        raise TraitError(
            "{}: displaylimit cannot be a negative integer".format(value)
        )
    return value
@observe("autopandas", "autopolars")
def _mutex_autopandas_autopolars(self, change):
    """Keep ``autopandas`` and ``autopolars`` from being enabled together.

    When one of them is switched on while the other is already on, the other
    is switched off and a message tells the user so. Switching an option off
    does nothing here.
    """
    if not change["new"]:
        return

    # Anything other than "autopandas" maps to "autopandas", as before
    counterpart = {"autopandas": "autopolars"}.get(change["name"], "autopandas")
    if not getattr(self, counterpart):
        return

    setattr(self, counterpart, False)
    display.message(
        f"Disabled '{counterpart}' since '{change['name']}' was enabled."
    )
def check_random_arguments(self, line="", cell=""):
    """Reject undeclared options on the first line of a cell magic.

    Nothing is checked when ``cell`` is the empty string (line magic). For a
    cell magic, ``line`` is tokenised and the first run of tokens starting
    with ``-`` is compared against the option strings declared on
    ``SqlMagic.execute``.

    Raises:
        exceptions.UsageError: for the first collected option that is not
            declared.
    """
    if cell == "":
        return

    option_tokens = []
    for token in shlex.split(line, posix=False):
        # A single "-" prefix test also covers "--" options
        if token.startswith("-"):
            option_tokens.append(token)
        elif option_tokens:
            # First non-option token after at least one option: stop scanning
            break

    known_options = _option_strings_from_parser(SqlMagic.execute.parser)
    for candidate in option_tokens:
        if candidate not in known_options:
            raise exceptions.UsageError(
                "Unrecognized argument(s): {}".format(candidate)
            )
@no_var_expand
@needs_local_scope
# The same function is registered under two names (sql, jupysql), each as both
# a line magic and a cell magic.
@line_magic("sql")
@cell_magic("sql")
@line_magic("jupysql")
@cell_magic("jupysql")
@magic_arguments()
# Positional remainder of the magic line (connection string and/or SQL)
@argument("line", default="", nargs="*", type=str, help="sql")
@argument(
    "-l", "--connections", action="store_true", help="list active connections"
)
@argument("-x", "--close", type=str, help="close a session by name")
@argument(
    "-c", "--creator", type=str, help="specify creator function for new connection"
)
@argument(
    "-s",
    "--section",
    type=str,
    help="section of dsn_file to be used for generating a connection string",
)
@argument(
    "-p",
    "--persist",
    action="store_true",
    help="create a table name in the database from the named DataFrame",
)
@argument(
    "-P",
    "--persist-replace",
    action="store_true",
    help="replace the DataFrame if it exists, otherwise perform --persist",
)
@argument(
    "-n",
    "--no-index",
    action="store_true",
    help="Do not store Data Frame index when persisting",
)
@argument(
    "--append",
    action="store_true",
    help=(
        "create, or append to, a table name in the database from the "
        "named DataFrame"
    ),
)
@argument(
    "-a",
    "--connection_arguments",
    type=str,
    help="specify dictionary of connection arguments to pass to SQL driver",
)
@argument("-f", "--file", type=str, help="Run SQL from file at this path")
@argument("-S", "--save", type=str, help="Save this query for later use")
# May be given several times; values accumulate in args.with_
@argument(
    "-w",
    "--with",
    type=str,
    help="Use a saved query",
    action="append",
    dest="with_",
)
@argument(
    "-N",
    "--no-execute",
    action="store_true",
    help="Do not execute query (use it with --save)",
)
@argument(
    "-A",
    "--alias",
    type=str,
    help="Assign an alias to the connection",
)
# May be given several times; each value names a variable looked up in
# local_ns by _execute
@argument(
    "--interact",
    type=str,
    action="append",
    help="Interactive mode",
)
def execute(self, line="", cell="", local_ns=None):
    """
    Runs SQL statement against a database, specified by
    SQLAlchemy connect string.

    If no database connection has been established, first word
    should be a SQLAlchemy connection string, or the user@db name
    of an established connection.

    Examples::

      %%sql postgresql://me:mypw@localhost/mydb
      SELECT * FROM mytable

      %%sql me@mydb
      DELETE FROM mytable

      %%sql
      DROP TABLE mytable

    SQLAlchemy connect string syntax examples:

      postgresql://me:mypw@localhost/mydb
      sqlite://
      mysql+pymysql://me:mypw@localhost/mydb

    """
    # All the work happens in _execute; this entry point always starts in
    # non-interactive mode.
    return self._execute(
        line=line, cell=cell, local_ns=local_ns, is_interactive_mode=False
    )
@modify_exceptions
def _execute(self, line, cell, local_ns, is_interactive_mode=False):
    """
    This function implements the cell logic; we create this private
    method so we can control how the function is called. Otherwise,
    decorating ``SqlMagic.execute`` will break when adding the ``@log_call``
    decorator with ``payload=True``

    NOTE: telemetry has been removed, we can remove this function

    Parameters
    ----------
    line : str
        Text after the magic name (options, connection string, SQL).
    cell : str
        Body of the cell; the empty string for a line magic.
    local_ns : dict or None
        Caller's local namespace; ``None`` is replaced by an empty dict.
    is_interactive_mode : bool
        True when re-entered from the ``--interact`` widget callback, which
        prevents creating the widgets again.

    Returns
    -------
    Depending on the options: the connections table, the result of closing a
    connection, the result of ``_persist_dataframe``, the query result, or
    ``None``.

    NOTE(review): the nesting below was reconstructed from a copy of the file
    that had lost its indentation - confirm against the upstream source.
    """

    def interactive_execute_wrapper(**kwargs):
        # Callback handed to ipywidgets' interact: copy the widget values into
        # local_ns, then run the same magic again in interactive mode.
        for key, value in kwargs.items():
            local_ns[key] = value
        return self._execute(line, cell, local_ns, is_interactive_mode=True)

    # line is the text after the magic, cell is the cell's body
    # Examples
    # %sql {line}
    # note that line magic has no body
    # %%sql {line}
    # {cell}
    self.check_random_arguments(line, cell)

    if local_ns is None:
        local_ns = {}

    # save globals and locals so they can be referenced in bind vars
    user_ns = self.shell.user_ns.copy()
    user_ns.update(local_ns)

    command = SQLCommand(self, user_ns, line, cell)
    # args.line: contains the line after the magic with all options removed
    args = command.args

    if util.is_rendering_required(line):
        util.expand_args(args, user_ns)

    if args.section and args.alias:
        raise exceptions.UsageError(
            "Cannot use --section with --alias since the section name "
            "is automatically set as the connection alias"
        )

    # The query counts as a CTE when its text starts with "with " (checked
    # case-insensitively after stripping surrounding whitespace)
    is_cte = command.sql_original.strip().lower().startswith("with ")
    # only expand CTE if this is not a CTE itself
    if not is_cte:
        if args.with_:
            # Snippets named explicitly with --with
            with_ = args.with_
        else:
            # No --with given: ask the store which snippets the query uses
            with_ = self._store.infer_dependencies(command.sql_original, args.save)
            if with_:
                query_type = get_query_type(command.sql_original)
                if query_type != "SELECT":
                    display.message_warning(
                        f"Your query is using the following snippets: \
{', '.join(with_)}. {IF_NOT_SELECT_MESSAGE} CTE generation is disabled"
                    )
                else:
                    command.set_sql_with(with_)
                    display.message(
                        f"Generating CTE with stored snippets: \
{util.pretty_print(with_)}"
                    )
            else:
                with_ = None
    else:
        query_type = get_query_type(command.sql_original)
        if args.with_:
            raise exceptions.UsageError(
                "Cannot use --with with CTEs, remove --with and re-run the cell"
            )
        dependencies = self._store.infer_dependencies(
            command.sql_original, args.save
        )
        if dependencies:
            # Only warn; the query is left untouched
            if query_type != "SELECT":
                display_message = IF_NOT_SELECT_MESSAGE
            else:
                display_message = IF_SELECT_MESSAGE
            display.message_warning(
                f"Your query is using one or more of the following snippets: \
{', '.join(dependencies)}. {display_message}\
CTE generation is disabled"
            )
        with_ = None

    # Create the interactive slider
    if args.interact and not is_interactive_mode:
        check_installed(["ipywidgets"], "--interactive argument")
        # Initial widget values come from the caller's local namespace; a name
        # that is not defined there raises KeyError
        interactive_dict = {}
        for key in args.interact:
            interactive_dict[key] = local_ns[key]
        display.message(
            "Interactive mode, please interact with below "
            "widget(s) to control the variable"
        )
        interact(interactive_execute_wrapper, **interactive_dict)
        return

    if args.connections:
        return sql.connection.ConnectionManager.connections_table()
    elif args.close:
        return sql.connection.ConnectionManager.close_connection_with_descriptor(
            args.close
        )

    connect_arg = command.connection

    if args.section:
        connect_arg = sql.parse.connection_str_from_dsn_section(args.section, self)

    if args.connection_arguments:
        try:
            # check for string delimiters, we need to strip them for json parse
            raw_args = args.connection_arguments
            if len(raw_args) > 1:
                targets = ['"', "'"]
                head = raw_args[0]
                tail = raw_args[-1]
                # Strip one pair of matching surrounding quotes
                if head in targets and head == tail:
                    raw_args = raw_args[1:-1]
            args.connection_arguments = json.loads(raw_args)
        except Exception as e:
            raise exceptions.ValueError(str(e)) from e
    else:
        args.connection_arguments = {}

    if args.creator:
        # --creator names a variable; swap the name for the object it refers to
        args.creator = user_ns[args.creator]

    # this creates a new connection or use an existing one
    # depending on the connect_arg value
    conn = sql.connection.ConnectionManager.set(
        connect_arg,
        displaycon=self.displaycon,
        connect_args=args.connection_arguments,
        creator=args.creator,
        alias=args.section if args.section else args.alias,
        config=self,
    )

    if args.persist_replace and args.append:
        raise exceptions.UsageError(
            """You cannot simultaneously persist and append data to a dataframe;
please choose to utilize either one or the other."""
        )

    if args.persist and args.persist_replace:
        # Both flags given: warn, then behave like --persist-replace
        warnings.warn("Please use either --persist or --persist-replace")
        return self._persist_dataframe(
            command.sql,
            conn,
            user_ns,
            append=False,
            index=not args.no_index,
            replace=True,
        )
    elif args.persist:
        return self._persist_dataframe(
            command.sql, conn, user_ns, append=False, index=not args.no_index
        )
    elif args.persist_replace:
        return self._persist_dataframe(
            command.sql,
            conn,
            user_ns,
            append=False,
            index=not args.no_index,
            replace=True,
        )

    if args.append:
        return self._persist_dataframe(
            command.sql, conn, user_ns, append=True, index=not args.no_index
        )

    # Nothing to run (the connection above has already been set)
    if not command.sql:
        return

    # store the query if needed
    if args.save:
        if "-" in args.save:
            warnings.warn(
                "Using hyphens will be deprecated soon, "
                "please use "
                + str(args.save.replace("-", "_"))
                + " instead for the save argument.",
                FutureWarning,
            )
        self._store.store(args.save, command.sql_original, with_=with_)

    if args.no_execute:
        display.message("Skipping execution...")
        return

    # parameters stays None unless named_parameters is "disabled" (empty
    # mapping) or "enabled" (the merged user namespace)
    parameters = None
    if self.named_parameters == "disabled":
        parameters = {}
    elif self.named_parameters == "enabled":
        parameters = user_ns

    try:
        result = run_statements(conn, command.sql, self, parameters=parameters)
        if (
            result is not None
            and not isinstance(result, str)
            and self.column_local_vars
        ):
            # Instead of returning values, set variables directly in the
            # users namespace. Variable names given by column names
            if self.autopandas or self.autopolars:
                keys = result.keys()
            else:
                # Here ``keys`` is read as an attribute rather than called,
                # and the result is converted with ``dict()`` before being
                # merged into the namespace
                keys = result.keys
                result = result.dict()
            if self.feedback:
                display.message(
                    "Returning data to local variables [{}]".format(", ".join(keys))
                )
            self.shell.user_ns.update(result)
            return None
        else:
            if command.result_var:
                # Store the result under the requested variable name; it is
                # only returned as well when the command asks for it
                self.shell.user_ns.update({command.result_var: result})
                if command.return_result_var:
                    return result
                return None
            # Return results into the default ipython _ variable
            return result
    # JA: added DatabaseError for MySQL
    except (
        ProgrammingError,
        OperationalError,
        DatabaseError,
        # raised when they query has :parameters but no parameters are given
        StatementError,
    ) as e:
        # Sqlite apparently return all errors as OperationalError :/
        handle_exception(e, command.sql, self.short_errors)
    except Exception as e:
        # Handle non SQLAlchemy errors
        handle_exception(e, command.sql, self.short_errors)
# Matches the leading run of ASCII letters, digits, "#", "_" and "$";
# _persist_dataframe applies it to the lower-cased frame name to obtain the
# table name.
legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+")
@modify_exceptions
def _persist_dataframe(
    self, raw, conn, user_ns, append=False, index=True, replace=False
):
    """Implements PERSIST, which writes a DataFrame to the RDBMS

    Parameters
    ----------
    raw : str
        Name of the variable to persist, optionally written as
        ``schema.name``. Leading and trailing ``;`` are stripped.
    conn
        Connection object; its ``to_table`` method performs the write.
    user_ns : dict
        Namespace in which the variable name is looked up.
    append : bool
        Use ``if_exists="append"`` (ignored when ``replace`` is true).
    index : bool
        Forwarded to ``conn.to_table`` as ``index``.
    replace : bool
        Use ``if_exists="replace"``. With neither flag, ``"fail"`` is used.

    Raises
    ------
    exceptions.MissingPackageError
        pandas could not be imported.
    exceptions.UsageError
        The name is missing, is not a valid identifier, or is not defined in
        ``user_ns``.
    exceptions.TypeError
        The variable is neither a pandas DataFrame nor a Series.
    """
    # DataFrame is None when the module-level pandas import failed
    if DataFrame is None:
        raise exceptions.MissingPackageError(
            "You must install pandas to persist results: pip install pandas"
        )

    frame_name = raw.strip(";")

    # user may pass schema.dataframe (required for certain DBs
    # like Trino)
    schema_name = None
    if "." in frame_name:
        schema_frame = frame_name.split(".")
        schema_name = schema_frame[0]
        frame_name = schema_frame[1]

    # missing argument: this must run before the identifier check, because
    # "".isidentifier() is False and would otherwise mask this message
    if not frame_name:
        raise exceptions.UsageError(
            "Missing argument: %sql --persist <name_of_data_frame>"
        )

    # invalid identifier
    if not frame_name.isidentifier():
        raise exceptions.UsageError(
            f"Expected {frame_name!r} to be a pd.DataFrame but it's"
            " not a valid identifier"
        )

    # undefined variable
    if frame_name not in user_ns:
        raise exceptions.UsageError(
            f"Expected {frame_name!r} to be a pd.DataFrame but it's undefined"
        )

    frame = user_ns[frame_name]
    if not isinstance(frame, (DataFrame, Series)):
        raise exceptions.TypeError(
            f"{frame_name!r} is not a Pandas DataFrame or Series"
        )

    # Make a suitable name for the resulting database table
    table_name = frame_name.lower()
    table_name = self.legal_sql_identifier.search(table_name).group(0)

    # replace wins over append; with neither flag an existing table is an error
    if replace:
        if_exists = "replace"
    elif append:
        if_exists = "append"
    else:
        if_exists = "fail"

    conn.to_table(
        table_name=table_name,
        data_frame=frame,
        if_exists=if_exists,
        index=index,
        schema=schema_name,
    )
def get_query_type(command: str):
    """
    Returns the query type of the original sql command

    The type is taken from the first statement that sqlparse finds in
    ``command``. ``None`` is returned when nothing was parsed or when
    sqlparse reports the type as "UNKNOWN".
    """
    statements = sqlparse.parse(command)
    if not statements:
        return None

    first_type = statements[0].get_type()
    return None if first_type == "UNKNOWN" else first_type
def set_configs(ip, file_path, alternate_path):
    """Set user defined SqlMagic configuration settings

    Reads the user configuration through ``util.get_user_configs`` and applies
    each entry whose key is a known setting and whose value has the same type
    as that setting's default.

    Parameters
    ----------
    ip
        IPython shell; the object owning its ``sql`` cell magic is configured.
    file_path, alternate_path
        Forwarded unchanged to ``util.get_user_configs``.

    Returns
    -------
    list
        ``[name, value]`` pairs for every setting that was applied.
    """
    # NOTE: this local name shadows the imported ``sql`` package inside the
    # function; here it is the object the "sql" cell magic is bound to
    sql = ip.find_cell_magic("sql").__self__
    user_configs, loaded_from = util.get_user_configs(file_path, alternate_path)
    default_configs = util.get_default_configs(sql)
    table_rows = []
    # Becomes True once at least one setting has been applied
    success = False
    if user_configs:
        for config, value in user_configs.items():
            if config in default_configs.keys():
                # A value is accepted only if it is an instance of the type of
                # the current default
                default_type = type(default_configs[config])
                if isinstance(value, default_type):
                    setattr(sql, config, value)
                    table_rows.append([config, value])
                    success = True
                else:
                    display.message(
                        f"'{value}' is an invalid value for '{config}'. "
                        f"Please use {default_type.__name__} value instead."
                    )
            else:
                # Unknown key: handled by util.find_close_match_config
                util.find_close_match_config(config, default_configs.keys())
    # NOTE(review): this copy of the file lost its indentation, so the pairing
    # of the ``else`` below (with ``if success`` rather than with
    # ``if loaded_from is not None`` or ``if user_configs``) is a
    # reconstruction - confirm against the upstream source.
    if success:
        if loaded_from is not None:
            display.message(f"Loading configurations from {loaded_from}.")
    else:
        display.message("Loading default configurations.")
    return table_rows
def load_SqlMagic_configs(ip):
    """Loads saved SqlMagic configs in pyproject.toml or ~/.jupysql/config

    Exceptions whose class is named ``TomlDecodeError`` or
    ``ModuleNotFoundError`` are reported to the user and end the function;
    anything else is re-raised. When settings were applied they are shown as
    a table.
    """
    file_path = util.find_path_from_root("pyproject.toml")
    alternate_path = Path("~/.jupysql/config").expanduser()

    try:
        table_rows = set_configs(ip, file_path, alternate_path)
    except Exception as exc:
        # Dispatch on the class name so no TOML library has to be imported here
        exc_name = type(exc).__name__

        if exc_name == "TomlDecodeError":
            display.message_warning(
                f"Could not load configuration file at {file_path}"
                f"{(' or ' + str(alternate_path)) if alternate_path else ''}"
                " (default configuration will be used).\nPlease "
                f"check that it is valid TOML: {exc}"
            )
            return

        if exc_name != "ModuleNotFoundError":
            raise

        display.message(
            "The 'toml' package isn't installed. To load settings from "
            "pyproject.toml or ~/.jupysql/config, install with: "
            "pip install toml"
        )
        return

    if table_rows:
        display.message("Settings changed:")
        display.table(["Config", "value"], table_rows)
def load_ipython_extension(ip):
    """Load the magics, this function is executed when the user runs: %load_ext sql"""
    magic = SqlMagic(ip)
    _set_sql_magic(magic)
    ip.register_magics(magic)

    # Settings are loaded after the magic is registered, because set_configs
    # finds the instance through ip.find_cell_magic("sql")
    load_SqlMagic_configs(ip)

    # start the default connection if the user has one in their config file
    manager = sql.connection.ConnectionManager
    manager.load_default_connection_from_file_if_any(config=magic)

    # The remaining magics are registered by class, in this order
    for magics_cls in (RenderMagic, SqlPlotMagic, SqlCmdMagic):
        ip.register_magics(magics_cls)

    patch_ipython_usage_error(ip)