diff --git a/CHANGELOG.md b/CHANGELOG.md index 3738a9bb6..21a910d88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,20 +2,19 @@ ## 0.7.7dev +* [Doc] Hiding connection string when passing `--alias` when opening a connection (#432) +* [Doc] Fix `api/magic-sql.md` since it incorrectly stated that listing functions was `--list`, but it's `--connections` (#432) +* [Feature] Clearer message display when executing queries, listing connections and persisting data frames (#432) +* [Feature] `%sql --connections` now displays an HTML table in Jupyter and a text-based table in the terminal * [Fix] Fix CTE generation when the snippets have trailing semicolons - * [Doc] Added Howto documentation for enabling JupyterLab cell runtime display ([#448](https://github.com/ploomber/jupysql/issues/448)) ## 0.7.6 (2023-05-29) * [Feature] Add `%sqlcmd explore` to explore tables interactively ([#330](https://github.com/ploomber/jupysql/issues/330)) - * [Feature] Support for printing capture variables using `=<<` syntax (by [@jorisroovers](https://github.com/jorisroovers)) - * [Feature] Adds `--persist-replace` argument to replace existing tables when persisting data frames ([#440](https://github.com/ploomber/jupysql/issues/440)) - * [Fix] Fix error when checking if custom connection was PEP 249 Compliant ([#517](https://github.com/ploomber/jupysql/issues/517)) - * [Doc] documenting how to manage connections with `Connection` object ([#282](https://github.com/ploomber/jupysql/issues/282)) * [Feature] Github Codespace (Devcontainer) support for development (by [@jorisroovers](https://github.com/jorisroovers)) ([#484](https://github.com/ploomber/jupysql/issues/484)) diff --git a/doc/api/magic-sql.md b/doc/api/magic-sql.md index 3228e2793..82b04c0ff 100644 --- a/doc/api/magic-sql.md +++ b/doc/api/magic-sql.md @@ -111,7 +111,7 @@ To make all subsequent queries to use certain connection, pass the connection na You can inspect which is the current active connection: ```{code-cell} ipython3 -%sql --list +%sql --connections ``` For more details on managing connections, see [Switch connections](../howto.md#switch-connections). @@ -121,7 +121,7 @@ For more details on managing connections, see [Switch connections](../howto.md#s ## List connections ```{code-cell} ipython3 -%sql --list +%sql --connections ``` ## Close connection diff --git a/doc/community/developer-guide.md b/doc/community/developer-guide.md index 69344f969..90ccea9da 100644 --- a/doc/community/developer-guide.md +++ b/doc/community/developer-guide.md @@ -35,6 +35,27 @@ After the codespace has finished setting up, you can run `conda activate jupysql +++ +## Displaying messages + +```{important} +Use the `sql.display` module instead of `print` for showing feedback to the user. +``` + +You can use `message` (contextual information) and `message_success` (successful operations) to show feedback to the user. Here's an example: + +```{code-cell} ipython3 +from sql.display import message, message_success +``` + +```{code-cell} ipython3 +message("Some information") +``` + +```{code-cell} ipython3 +message_success("Some operation finished successfully!") +``` + + ## Throwing errors When writing Python libraries, we often throw errors (and display error tracebacks) to let users know that something went wrong. However, JupySQL is an abstraction for executing SQL queries; hence, Python tracebacks a useless to end-users since they expose JupySQL's internals. diff --git a/doc/howto.md b/doc/howto.md index 94ed08452..c548fa90f 100644 --- a/doc/howto.md +++ b/doc/howto.md @@ -389,3 +389,26 @@ import warnings warnings.filterwarnings("ignore", category=FutureWarning) ``` + +```{code-cell} ipython3 +conns = %sql --connections +conns["db-three"] +``` + +## Hide connection string + +If you want to hide the connection string, pass an alias + +```{code-cell} ipython3 +%sql --close duckdb:// +``` + +```{code-cell} ipython3 +%sql duckdb:// --alias myconnection +``` + +The alias will be displayed instead of the connection string: + +```{code-cell} ipython3 +%sql SELECT * FROM 'penguins.csv' LIMIT 3 +``` diff --git a/src/sql/command.py b/src/sql/command.py index 2e3eeb252..2f4ad7093 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -20,6 +20,9 @@ class SQLCommand: """ def __init__(self, magic, user_ns, line, cell) -> None: + self._line = line + self._cell = cell + self.args = parse.magic_args(magic.execute, line) # self.args.line (everything that appears after %sql/%%sql in the first line) # is split in tokens (delimited by spaces), this checks if we have one arg @@ -103,3 +106,9 @@ def return_result_var(self): def _var_expand(self, sql, user_ns, magic): return Template(sql).render(user_ns) + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(line={self._line!r}, cell={self._cell!r}) -> " + f"({self.sql!r}, {self.sql_original!r})" + ) diff --git a/src/sql/connection.py b/src/sql/connection.py index 166d0ec90..c6f2e327a 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -10,7 +10,7 @@ from sql.store import store from sql.telemetry import telemetry -from sql import exceptions +from sql import exceptions, display from sql.error_message import detail from ploomber_core.exceptions import modify_exceptions @@ -153,6 +153,30 @@ class Connection: ---------- engine: sqlalchemy.engine.Engine The SQLAlchemy engine to use + + Attributes + ---------- + alias : str or None + The alias passed in the constructor + + engine : sqlalchemy.engine.Engine + The SQLAlchemy engine passed to the constructor + + name : str + A name to identify the connection: {user}@{database_name} + + metadata : Metadata or None + An SQLAlchemy Metadata object (if using SQLAlchemy 2, this is None), + used to retrieve connection information + + url : str + An obfuscated connection string (password hidden) + + dialect : sqlalchemy dialect + A SQLAlchemy dialect object + + session : sqlalchemy session + A SQLAlchemy session object """ # the active connection @@ -162,26 +186,26 @@ class Connection: connections = {} def __init__(self, engine, alias=None): - self.url = engine.url - self.name = self.assign_name(engine) - self.dialect = self.url.get_dialect() + self.alias = alias self.engine = engine + self.name = self.assign_name(engine) if IS_SQLALCHEMY_ONE: self.metadata = sqlalchemy.MetaData(bind=engine) + else: + self.metadata = None - url = ( + self.url = ( repr(sqlalchemy.MetaData(bind=engine).bind.url) if IS_SQLALCHEMY_ONE else repr(engine.url) ) - self.session = self._create_session(engine, url) - - self.connections[alias or url] = self + self.dialect = engine.url.get_dialect() + self.session = self._create_session(engine, self.url) + self.connections[alias or self.url] = self self.connect_args = None - self.alias = alias Connection.current = self @classmethod @@ -362,8 +386,7 @@ def set(cls, descriptor, displaycon, connect_args=None, creator=None, alias=None else: if cls.connections: if displaycon: - # display list of connections - print(cls.connection_list()) + cls.display_current_connection() elif os.getenv("DATABASE_URL"): cls.current = Connection.from_connect_str( connect_str=os.getenv("DATABASE_URL"), @@ -382,27 +405,53 @@ def assign_name(cls, engine): return name @classmethod - def connection_list(cls): - """Returns the list of connections, appending '*' to the current one""" - result = [] + def _get_connections(cls): + """ + Return a list of dictionaries + """ + connections = [] + for key in sorted(cls.connections): conn = cls.connections[key] - if cls.is_custom_connection(conn): - engine_url = conn.url - else: - engine_url = conn.metadata.bind.url if IS_SQLALCHEMY_ONE else conn.url + current = conn == cls.current - prefix = "* " if conn == cls.current else " " + connections.append( + { + "current": current, + "key": key, + "url": conn.url, + "alias": conn.alias, + "connection": conn, + } + ) - if conn.alias: - repr_ = f"{prefix} ({conn.alias}) {engine_url!r}" - else: - repr_ = f"{prefix} {engine_url!r}" + return connections - result.append(repr_) + @classmethod + def display_current_connection(cls): + for conn in cls._get_connections(): + if conn["current"]: + alias = conn.get("alias") + if alias: + display.message(f"Running query in {alias!r}") + else: + display.message(f"Running query in {conn['url']!r}") - return "\n".join(result) + @classmethod + def connections_table(cls): + """Returns the current connections as a table""" + connections = cls._get_connections() + + def map_values(d): + d["current"] = "*" if d["current"] else "" + d["alias"] = d["alias"] if d["alias"] else "" + return d + + return display.ConnectionsTable( + headers=["current", "url", "alias"], + rows_maps=[map_values(c) for c in connections], + ) @classmethod def close(cls, descriptor): @@ -426,6 +475,15 @@ def close(cls, descriptor): ) conn.session.close() + @classmethod + def close_all(cls): + """Close all active connections""" + connections = Connection.connections.copy() + for key, conn in connections.items(): + conn.close(key) + + cls.connections = {} + def is_custom_connection(conn=None) -> bool: """ Checks if given connection is custom diff --git a/src/sql/display.py b/src/sql/display.py new file mode 100644 index 000000000..c0c2e7f6c --- /dev/null +++ b/src/sql/display.py @@ -0,0 +1,92 @@ +""" +A module to display confirmation messages and contextual information to the user +""" +import html + +from prettytable import PrettyTable +from IPython.display import display + + +class Table: + """Provides a txt and html representation of tabular data""" + + TITLE = "" + + def __init__(self, headers, rows) -> None: + self._headers = headers + self._rows = rows + self._table = PrettyTable() + self._table.field_names = headers + + for row in rows: + self._table.add_row(row) + + self._table_html = self._table.get_html_string() + self._table_txt = self._table.get_string() + + def __repr__(self) -> str: + return self.TITLE + "\n" + self._table_txt + + def _repr_html_(self) -> str: + return self.TITLE + "\n" + self._table_html + + +class ConnectionsTable(Table): + TITLE = "Active connections:" + + def __init__(self, headers, rows_maps) -> None: + def get_values(d): + d = {k: v for k, v in d.items() if k not in {"connection", "key"}} + return list(d.values()) + + rows = [get_values(r) for r in rows_maps] + + self._mapping = {} + + for row in rows_maps: + self._mapping[row["key"]] = row["connection"] + + super().__init__(headers=headers, rows=rows) + + def __getitem__(self, key: str): + """ + This method is provided for backwards compatibility. Before + creating ConnectionsTable, `%sql --connections` returned a dictionary, + hence users could retrieve connections using __getitem__. Note that this + was undocumented so we might decide to remove it in the future. + """ + return self._mapping[key] + + def __iter__(self): + """Also provided for backwards compatibility""" + for key in self._mapping: + yield key + + def __len__(self): + """Also provided for backwards compatibility""" + return len(self._mapping) + + +class Message: + """Message for the user""" + + def __init__(self, message, style=None) -> None: + self._message = message + self._message_html = html.escape(message) + self._style = "" or style + + def _repr_html_(self): + return f'{self._message_html}' + + def __repr__(self) -> str: + return self._message + + +def message(message): + """Display a generic message""" + display(Message(message)) + + +def message_success(message): + """Display a success message""" + display(Message(message, style="color: green")) diff --git a/src/sql/magic.py b/src/sql/magic.py index 2e06ff27a..24e8277fa 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -21,7 +21,7 @@ import sql.connection import sql.parse import sql.run -from sql import exceptions +from sql import display, exceptions from sql.store import store from sql.command import SQLCommand from sql.magic_plot import SqlPlotMagic @@ -336,7 +336,7 @@ def interactive_execute_wrapper(**kwargs): interact(interactive_execute_wrapper, **interactive_dict) return if args.connections: - return sql.connection.Connection.connections + return sql.connection.Connection.connections_table() elif args.close: return sql.connection.Connection.close(args.close) @@ -374,6 +374,7 @@ def interactive_execute_wrapper(**kwargs): alias=args.alias, ) payload["connection_info"] = conn._get_curr_sqlalchemy_connection_info() + if args.persist_replace and args.append: raise exceptions.UsageError( """You cannot simultaneously persist and append data to a dataframe; @@ -422,7 +423,7 @@ def interactive_execute_wrapper(**kwargs): self._store.store(args.save, command.sql_original, with_=args.with_) if args.no_execute: - print("Skipping execution...") + display.message("Skipping execution...") return try: @@ -537,7 +538,7 @@ def _persist_dataframe( --persist-replace to drop the table before persisting the data frame""" ) - return "Persisted %s" % table_name + display.message_success(f"Success! Persisted {table_name} to the database.") def load_ipython_extension(ip): diff --git a/src/sql/run.py b/src/sql/run.py index 301bb2361..b9f0e9436 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -11,7 +11,7 @@ import sqlalchemy import sqlparse from sql.connection import Connection -from sql import exceptions +from sql import exceptions, display from .column_guesser import ColumnGuesserMixin try: @@ -365,12 +365,9 @@ def csv(self, filename=None, **format_params): return outfile.getvalue() -def interpret_rowcount(rowcount): - if rowcount < 0: - result = "Done." - else: - result = "%d rows affected." % rowcount - return result +def display_affected_rowcount(rowcount): + if rowcount > 0: + display.message_success(f"{rowcount} rows affected.") class FakeResultProxy(object): @@ -551,7 +548,7 @@ def run(conn, sql, config): if result and config.feedback: if hasattr(result, "rowcount"): - print(interpret_rowcount(result.rowcount)) + display_affected_rowcount(result.rowcount) # bypass ResultSet and use duckdb's native method to return a pandas data frame if duckdb_autopandas: diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 3225b6a65..63ce88830 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -56,6 +56,7 @@ def ip_empty(): ip_session.register_magics(SqlPlotMagic) ip_session.register_magics(SqlCmdMagic) yield ip_session + Connection.close_all() @pytest.fixture diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 52d1d0b35..f83c4dfb9 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -3,6 +3,8 @@ import pytest from sqlalchemy import create_engine from sqlalchemy.engine import Engine +from sqlalchemy.exc import ResourceClosedError + import sql.connection from sql.connection import Connection, CustomConnection from IPython.core.error import UsageError @@ -33,7 +35,10 @@ def mock_postgres(monkeypatch, cleanup): def test_password_isnt_displayed(mock_postgres): Connection.from_connect_str("postgresql://user:topsecret@somedomain.com/db") - assert "topsecret" not in Connection.connection_list() + table = Connection.connections_table() + + assert "topsecret" not in str(table) + assert "topsecret" not in table._repr_html_() def test_connection_name(mock_postgres): @@ -269,6 +274,56 @@ def test_no_current_connection_and_get_info(monkeypatch, mock_database): assert conn._get_curr_sqlalchemy_connection_info() is None +def test_get_connections(): + Connection(engine=create_engine("sqlite://")) + Connection(engine=create_engine("duckdb://")) + + assert Connection._get_connections() == [ + { + "url": "duckdb://", + "current": True, + "alias": None, + "key": "duckdb://", + "connection": ANY, + }, + { + "url": "sqlite://", + "current": False, + "alias": None, + "key": "sqlite://", + "connection": ANY, + }, + ] + + +def test_display_current_connection(capsys): + Connection(engine=create_engine("duckdb://")) + Connection.display_current_connection() + + captured = capsys.readouterr() + assert captured.out == "Running query in 'duckdb://'\n" + + +def test_connections_table(): + Connection(engine=create_engine("sqlite://")) + Connection(engine=create_engine("duckdb://")) + + connections = Connection.connections_table() + assert connections._headers == ["current", "url", "alias"] + assert connections._rows == [["*", "duckdb://", ""], ["", "sqlite://", ""]] + + +def test_properties(mock_postgres): + conn = Connection.from_connect_str("postgresql://user:topsecret@somedomain.com/db") + + assert "topsecret" not in conn.url + assert "***" in conn.url + assert conn.name == "user@db" + assert isinstance(conn.engine, Engine) + assert conn.dialect + assert conn.session + + class dummy_connection: def __init__(self): self.engine_name = "dummy_engine" @@ -305,3 +360,20 @@ def close(self): def test_custom_connection(conn, expected): is_custom = Connection.is_custom_connection(conn) assert is_custom == expected + + +def test_close_all(ip_empty): + ip_empty.run_cell("%sql duckdb://") + ip_empty.run_cell("%sql sqlite://") + + connections_copy = Connection.connections.copy() + + Connection.close_all() + + with pytest.raises(ResourceClosedError): + connections_copy["sqlite://"].execute("").fetchall() + + with pytest.raises(ResourceClosedError): + connections_copy["duckdb://"].execute("").fetchall() + + assert not Connection.connections diff --git a/src/tests/test_display.py b/src/tests/test_display.py new file mode 100644 index 000000000..08651f6be --- /dev/null +++ b/src/tests/test_display.py @@ -0,0 +1,8 @@ +from sql import display + + +def test_html_escaping(): + message = display.Message("<>") + + assert "<>" in str(message) + assert "<>" in message._repr_html_() diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 2b378880c..c811bbff6 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -462,13 +462,6 @@ def test_connection_args_single_quotes(ip): assert "timeout" in result.result["sqlite:///:memory:"].connect_args -@pytest.mark.skipif(platform.system() == "Windows", reason="failing on windows") -def test_connection_args_double_quotes(ip): - ip.run_cell('%sql --connection_arguments "{\\"timeout\\": 10}" sqlite:///:memory:') - result = ip.run_cell("%sql --connections") - assert "timeout" in result.result["sqlite:///:memory:"].connect_args - - # TODO: support # @with_setup(_setup_author, _teardown_author) # def test_persist_with_connection_info(): diff --git a/src/tests/test_magic_display.py b/src/tests/test_magic_display.py new file mode 100644 index 000000000..c30ed7e75 --- /dev/null +++ b/src/tests/test_magic_display.py @@ -0,0 +1,57 @@ +def test_connection_string_displayed(ip_empty, capsys): + ip_empty.run_cell("%sql duckdb://") + ip_empty.run_cell("%sql show tables") + + captured = capsys.readouterr() + assert "Running query in 'duckdb://'" in captured.out + + +def test_custom_connection_display(ip_empty, capsys, tmp_empty): + ip_empty.run_cell("import duckdb") + ip_empty.run_cell("custom = duckdb.connect('anotherdb')") + ip_empty.run_cell("%sql custom") + ip_empty.run_cell("%sql show tables") + + captured = capsys.readouterr() + assert "Running query in '