diff --git a/CHANGELOG.md b/CHANGELOG.md index 7813da07e..4e5656681 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## 0.7.9dev +* [Fix] Fixed `Set` method in `Connection` class to recognize same descriptor with different aliases (#532) * [Fix] Added bottom-padding to the buttons in table explorer. Now they are not hidden by the scrollbar (#540) * [Feature] Modified `histogram` command to support data with NULL values (#176) * [Fix] `psutil` is no longer a dependency for JupySQL ([#541](https://github.com/ploomber/jupysql/issues/541)) diff --git a/doc/api/magic-plot.md b/doc/api/magic-plot.md index 75751c13a..2d85b2c06 100644 --- a/doc/api/magic-plot.md +++ b/doc/api/magic-plot.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.5 + jupytext_version: 1.14.6 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -152,10 +152,11 @@ generate histograms without explicitly removing NULL entries. `%sqlplot` returns a `matplotlib.Axes` object. ```{code-cell} ipython3 -ax = %sqlplot histogram --table penguins.csv --column body_mass_g +ax = %sqlplot histogram --table penguins.csv --column body_mass_g ax.set_title("Body mass (grams)") _ = ax.grid() ``` + ## `%sqlplot bar` ```{versionadded} 0.7.6 @@ -196,7 +197,7 @@ You can also pass the orientation using the `orient` argument. ```{code-cell} ipython3 %sqlplot bar --table add_col --column species cnt --with add_col --orient h -``` +``` You can also show the number on top of the bar using the `S`/`show-numbers` argument. @@ -237,6 +238,7 @@ group by species ```{code-cell} ipython3 %sqlplot pie --table add_col --column species cnt --with add_col ``` + Here, `species` is the `labels` column and `cnt` is the `x` column. @@ -244,4 +246,4 @@ You can also show the percentage on top of the pie using the `S`/`show-numbers` ```{code-cell} ipython3 %sqlplot pie --table penguins.csv --column species -S -``` \ No newline at end of file +``` diff --git a/doc/api/magic-sql.md b/doc/api/magic-sql.md index 44f9a2e44..333df4fcd 100644 --- a/doc/api/magic-sql.md +++ b/doc/api/magic-sql.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.5 + jupytext_version: 1.14.6 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -136,8 +136,8 @@ Or pass an alias (**added in 0.5.2**): %sql --close db-two ``` - ## Specify creator function + ```{code-cell} ipython3 import os import sqlite3 @@ -146,6 +146,8 @@ import sqlite3 os.environ["DATABASE_URL"] = "sqlite:///" # Define a function that returns a DBAPI connection + + def creator(): return sqlite3.connect("") ``` @@ -154,7 +156,6 @@ def creator(): %sql --creator creator ``` - ## Create table ```{code-cell} ipython3 diff --git a/doc/howto/ggplot-interact.md b/doc/howto/ggplot-interact.md index deee3e6a2..fa3fea5ce 100644 --- a/doc/howto/ggplot-interact.md +++ b/doc/howto/ggplot-interact.md @@ -7,7 +7,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.5 + jupytext_version: 1.14.6 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -193,8 +193,11 @@ show_legend = widgets.ToggleButton( ```{code-cell} ipython3 def plot(b, cmap, show_legend): - (ggplot("diamonds", aes(x="price")) + geom_histogram(bins=b, fill="cut", cmap=cmap) - + facet_wrap("color", legend=show_legend)) + ( + ggplot("diamonds", aes(x="price")) + + geom_histogram(bins=b, fill="cut", cmap=cmap) + + facet_wrap("color", legend=show_legend) + ) ``` ```{code-cell} ipython3 diff --git a/src/sql/connection.py b/src/sql/connection.py index 8128866df..d5435055b 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -360,7 +360,6 @@ def set(cls, descriptor, displaycon, connect_args=None, creator=None, alias=None if descriptor: is_custom_connection_ = Connection.is_custom_connection(descriptor) - if isinstance(descriptor, Connection): cls.current = descriptor elif isinstance(descriptor, Engine): @@ -369,7 +368,6 @@ def set(cls, descriptor, displaycon, connect_args=None, creator=None, alias=None cls.current = CustomConnection(descriptor, alias=alias) else: existing = rough_dict_get(cls.connections, descriptor) - # NOTE: I added one indentation level, otherwise # the "existing" variable would not exist if # passing an engine object as descriptor. @@ -377,12 +375,20 @@ def set(cls, descriptor, displaycon, connect_args=None, creator=None, alias=None # is that we're missing some unit tests # when descriptor is a connection object # http://docs.sqlalchemy.org/en/rel_0_9/core/engines.html#custom-dbapi-connect-arguments # noqa - cls.current = existing or Connection.from_connect_str( - connect_str=descriptor, - connect_args=connect_args, - creator=creator, - alias=alias, - ) + # if same alias found + if existing and existing.alias == alias: + cls.current = existing + # if just switching connections + elif existing and alias is None: + cls.current = existing + # if new alias connection + elif existing is None or existing.alias != alias: + cls.current = Connection.from_connect_str( + connect_str=descriptor, + connect_args=connect_args, + creator=creator, + alias=alias, + ) else: if cls.connections: @@ -397,7 +403,6 @@ def set(cls, descriptor, displaycon, connect_args=None, creator=None, alias=None ) else: raise cls._error_no_connection() - return cls.current @classmethod diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index f83c4dfb9..daa6dac08 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -377,3 +377,27 @@ def test_close_all(ip_empty): connections_copy["duckdb://"].execute("").fetchall() assert not Connection.connections + + +@pytest.mark.parametrize( + "old_alias, new_alias", + [ + (None, "duck1"), + ("duck1", "duck2"), + (None, None), + ], +) +def test_new_connection_with_alias(ip_empty, old_alias, new_alias): + """Test if a new connection with the same url but a + new alias is registered for different cases of old alias + """ + ip_empty.run_cell(f"%sql duckdb:// --alias {old_alias}") + ip_empty.run_cell(f"%sql duckdb:// --alias {new_alias}") + table = ip_empty.run_cell("sql --connections").result + if old_alias is None and new_alias is None: + assert new_alias not in table + else: + connection = table[new_alias] + assert connection + assert connection.url == "duckdb://" + assert connection == connection.current