diff --git a/src/sql/magic.py b/src/sql/magic.py index 0384e6077..2f25c4252 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -43,9 +43,6 @@ from sql.error_handler import handle_exception from sql._current import _set_sql_magic -from jinja2 import Template - - from ploomber_core.dependencies import check_installed @@ -405,7 +402,7 @@ def interactive_execute_wrapper(**kwargs): user_ns = self.shell.user_ns.copy() user_ns.update(local_ns) - line = Template(line).render(user_ns) + line = util.expand_args(line, user_ns) command = SQLCommand(self, user_ns, line, cell) # args.line: contains the line after the magic with all options removed diff --git a/src/sql/magic_cmd.py b/src/sql/magic_cmd.py index 163a062e3..024856ae4 100644 --- a/src/sql/magic_cmd.py +++ b/src/sql/magic_cmd.py @@ -1,7 +1,6 @@ import sys import argparse import shlex -from jinja2 import Template from IPython.core.magic import Magics, line_magic, magics_class, no_var_expand from IPython.core.magic_arguments import argument, magic_arguments @@ -14,7 +13,7 @@ from sql.cmd.snippets import snippets from sql.cmd.connect import connect from sql.connection import ConnectionManager -from sql.util import check_duplicate_arguments +from sql.util import check_duplicate_arguments, expand_args try: from traitlets.config.configurable import Configurable @@ -74,8 +73,7 @@ def _validate_execute_inputs(self, line): if line == "": raise exceptions.UsageError(VALID_COMMANDS_MSG) else: - user_ns = self.shell.user_ns.copy() - line = Template(line).render(user_ns) + line = expand_args(line, self.shell.user_ns.copy()) # directly use shlex since SqlCmdMagic does not use magic_args from parse.py split = shlex.split(line, posix=False) command, others = split[0].strip(), split[1:] diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index 53170e98f..8bf77836f 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -82,12 +82,10 @@ def execute(self, line="", cell="", local_ns=None): Plot magic """ - user_ns = self.shell.user_ns.copy() + util.expand_args(line, self.shell.user_ns.copy()) cmd = SQLPlotCommand(self, line) - util.expand_args_cmd(cmd.args, user_ns) - if len(cmd.args.column) == 1: column = cmd.args.column[0] else: diff --git a/src/sql/util.py b/src/sql/util.py index 31e501598..2c3837569 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -578,7 +578,7 @@ def enclose_table_with_double_quotations(table, conn): return _table -def expand_args_cmd(args, user_ns): +def expand_args(line, user_ns): """ Function to substitute command line arguments with variables defined by user in the IPython @@ -586,32 +586,11 @@ def expand_args_cmd(args, user_ns): Parameters ---------- - args : argparse.Namespace, - object to hold the command line arguments. + line : str, + input text after the magic user_ns : dict, User namespace of IPython kernel """ - for attribute in vars(args): - value = getattr(args, attribute) - if value: - if isinstance(value, list): - substituted_value = [] - for item in value: - if ( - isinstance(item, str) - and item.startswith("{{") - and item.endswith("}}") - ): - substituted_value.append(Template(item).render(user_ns)) - else: - substituted_value.append(item) - setattr(args, attribute, substituted_value) - else: - if ( - isinstance(value, str) - and value.startswith("{{") - and value.endswith("}}") - ): - setattr(args, attribute, Template(value).render(user_ns)) + return Template(line).render(user_ns)