Skip to content

Commit

Permalink
util changed
Browse files Browse the repository at this point in the history
  • Loading branch information
neelasha23 committed Jan 21, 2024
1 parent a6af9c7 commit e280f63
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 36 deletions.
5 changes: 1 addition & 4 deletions src/sql/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@
from sql.error_handler import handle_exception
from sql._current import _set_sql_magic

from jinja2 import Template


from ploomber_core.dependencies import check_installed


Expand Down Expand Up @@ -405,7 +402,7 @@ def interactive_execute_wrapper(**kwargs):
user_ns = self.shell.user_ns.copy()
user_ns.update(local_ns)

line = Template(line).render(user_ns)
line = util.expand_args(line, user_ns)

command = SQLCommand(self, user_ns, line, cell)
# args.line: contains the line after the magic with all options removed
Expand Down
6 changes: 2 additions & 4 deletions src/sql/magic_cmd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys
import argparse
import shlex
from jinja2 import Template

from IPython.core.magic import Magics, line_magic, magics_class, no_var_expand
from IPython.core.magic_arguments import argument, magic_arguments
Expand All @@ -14,7 +13,7 @@
from sql.cmd.snippets import snippets
from sql.cmd.connect import connect
from sql.connection import ConnectionManager
from sql.util import check_duplicate_arguments
from sql.util import check_duplicate_arguments, expand_args

try:
from traitlets.config.configurable import Configurable
Expand Down Expand Up @@ -74,8 +73,7 @@ def _validate_execute_inputs(self, line):
if line == "":
raise exceptions.UsageError(VALID_COMMANDS_MSG)
else:
user_ns = self.shell.user_ns.copy()
line = Template(line).render(user_ns)
line = expand_args(line, self.shell.user_ns.copy())
# directly use shlex since SqlCmdMagic does not use magic_args from parse.py
split = shlex.split(line, posix=False)
command, others = split[0].strip(), split[1:]
Expand Down
4 changes: 1 addition & 3 deletions src/sql/magic_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,10 @@ def execute(self, line="", cell="", local_ns=None):
Plot magic
"""

user_ns = self.shell.user_ns.copy()
util.expand_args(line, self.shell.user_ns.copy())

cmd = SQLPlotCommand(self, line)

util.expand_args_cmd(cmd.args, user_ns)

if len(cmd.args.column) == 1:
column = cmd.args.column[0]
else:
Expand Down
29 changes: 4 additions & 25 deletions src/sql/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,40 +578,19 @@ def enclose_table_with_double_quotations(table, conn):
return _table


def expand_args_cmd(args, user_ns):
def expand_args(line, user_ns):
"""
Function to substitute command line arguments
with variables defined by user in the IPython
kernel.
Parameters
----------
args : argparse.Namespace,
object to hold the command line arguments.
line : str,
input text after the magic
user_ns : dict,
User namespace of IPython kernel
"""

for attribute in vars(args):
value = getattr(args, attribute)
if value:
if isinstance(value, list):
substituted_value = []
for item in value:
if (
isinstance(item, str)
and item.startswith("{{")
and item.endswith("}}")
):
substituted_value.append(Template(item).render(user_ns))
else:
substituted_value.append(item)
setattr(args, attribute, substituted_value)
else:
if (
isinstance(value, str)
and value.startswith("{{")
and value.endswith("}}")
):
setattr(args, attribute, Template(value).render(user_ns))
return Template(line).render(user_ns)

0 comments on commit e280f63

Please sign in to comment.