Skip to content

Commit

Permalink
upgrades for prettytable (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
Palashio authored Mar 3, 2023
1 parent 5199b51 commit bed8b36
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# CHANGELOG

## 0.6.1dev
* [Fix] Adds support for prettytable 2.0

## 0.6.0 (2023-02-27)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

install_requires = [
"prettytable<1",
"prettytable",
"ipython>=1.0",
"sqlalchemy>=0.6.7,<2.0",
"sqlparse",
Expand Down
7 changes: 5 additions & 2 deletions src/sql/column_guesser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ class ColumnGuesserMixin(object):
pie: ... y
"""

def __init__(self):
self.keys = None

def _build_columns(self):
self.columns = [Column() for col in self.keys]
for row in self:
for (col_idx, col_val) in enumerate(row):
for col_idx, col_val in enumerate(row):
col = self.columns[col_idx]
col.append(col_val)
if (col_val is not None) and (not is_quantity(col_val)):
col.is_quantity = False

for (idx, key_name) in enumerate(self.keys):
for idx, key_name in enumerate(self.keys):
self.columns[idx].name = key_name

self.x = Column()
Expand Down
2 changes: 1 addition & 1 deletion src/sql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_missing_package_suggestion_str(e):
module_name, MISSING_PACKAGE_LIST_EXCEPT_MATCHERS.keys()
)
if close_matches:
return f"Perhaps you meant to use driver the dialect: \"{close_matches[0]}\""
return f'Perhaps you meant to use driver the dialect: "{close_matches[0]}"'
# Not found
return (
suggestion_prefix + "make sure you are using correct driver name:\n"
Expand Down
3 changes: 3 additions & 0 deletions src/sql/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def _execute(self, payload, line, cell, local_ns):
# %%sql {line}
# {cell}

if local_ns is None:
local_ns = {}

# save globals and locals so they can be referenced in bind vars
user_ns = self.shell.user_ns.copy()
user_ns.update(local_ns)
Expand Down
1 change: 0 additions & 1 deletion src/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def connection_from_dsn_section(section, config):


def _connection_string(s, config):

s = expandvars(s) # for environment variables
if "@" in s or "://" in s:
return s
Expand Down
8 changes: 2 additions & 6 deletions src/sql/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,7 @@ def boxplot(payload, table, column, *, orient="v", with_=None, conn=None):
if not conn:
conn = sql.connection.Connection.current.session

payload[
"connection_info"
] = sql.connection.Connection._get_curr_connection_info()
payload["connection_info"] = sql.connection.Connection._get_curr_connection_info()

ax = plt.gca()
vert = orient == "v"
Expand Down Expand Up @@ -328,9 +326,7 @@ def histogram(payload, table, column, bins, with_=None, conn=None):
.. plot:: ../examples/plot_histogram_many.py
"""
ax = plt.gca()
payload[
"connection_info"
] = sql.connection.Connection._get_curr_connection_info()
payload["connection_info"] = sql.connection.Connection._get_curr_connection_info()
if isinstance(column, str):
bin_, height = _histogram(table, column, bins, with_=with_, conn=conn)
ax.bar(bin_, height, align="center", width=bin_[-1] - bin_[-2])
Expand Down
20 changes: 9 additions & 11 deletions src/sql/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,19 @@ class ResultSet(list, ColumnGuesserMixin):
Can access rows listwise, or by string value of leftmost column.
"""

def __init__(self, sqlaproxy, sql, config):
self.keys = sqlaproxy.keys()
self.sql = sql
def __init__(self, sqlaproxy, config):
self.config = config
self.limit = config.autolimit
style_name = config.style
self.style = prettytable.__dict__[style_name.upper()]
self.keys = {}
if sqlaproxy.returns_rows:
if self.limit:
list.__init__(self, sqlaproxy.fetchmany(size=self.limit))
self.keys = sqlaproxy.keys()
if config.autolimit:
list.__init__(self, sqlaproxy.fetchmany(size=config.autolimit))
else:
list.__init__(self, sqlaproxy.fetchall())
self.field_names = unduplicate_field_names(self.keys)
self.pretty = PrettyTable(self.field_names, style=self.style)
# self.pretty.set_style(self.style)
self.pretty = PrettyTable(
self.field_names, style=prettytable.__dict__[config.style.upper()]
)
else:
list.__init__(self, [])
self.pretty = None
Expand Down Expand Up @@ -405,7 +403,7 @@ def run(conn, sql, config, user_namespace):
_commit(conn=conn, config=config)
if result and config.feedback:
print(interpret_rowcount(result.rowcount))
resultset = ResultSet(result, statement, config)
resultset = ResultSet(result, config)
if config.autopandas:
return resultset.DataFrame()
elif config.autopolars:
Expand Down
5 changes: 4 additions & 1 deletion src/tests/test_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,22 +255,25 @@ def test_autopolars(ip):
dframe = runsql(ip, "SELECT * FROM test;")

import polars as pl

assert type(dframe) == pl.DataFrame
assert not dframe.is_empty()
assert len(dframe.shape) == 2
assert dframe['name'][0] == "foo"
assert dframe["name"][0] == "foo"


def test_mutex_autopolars_autopandas(ip):
dframe = runsql(ip, "SELECT * FROM test;")
assert type(dframe) == ResultSet

import polars as pl

ip.run_line_magic("config", "SqlMagic.autopolars = True")
dframe = runsql(ip, "SELECT * FROM test;")
assert type(dframe) == pl.DataFrame

import pandas as pd

ip.run_line_magic("config", "SqlMagic.autopandas = True")
dframe = runsql(ip, "SELECT * FROM test;")
assert type(dframe) == pd.DataFrame
Expand Down
10 changes: 0 additions & 10 deletions src/tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ class DummyConfig:


def test_connection_from_dsn_section():

result = connection_from_dsn_section(section="DB_CONFIG_1", config=DummyConfig())
assert result == "postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain"
result = connection_from_dsn_section(section="DB_CONFIG_2", config=DummyConfig())
Expand Down Expand Up @@ -132,54 +131,46 @@ class ParserStub:


def test_without_sql_comment_plain():

line = "SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == line


def test_without_sql_comment_with_arg():

line = "--file moo.txt --persist SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == line


def test_without_sql_comment_with_comment():

line = "SELECT * FROM author -- uff da"
expected = "SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected


def test_without_sql_comment_with_arg_and_comment():

line = "--file moo.txt --persist SELECT * FROM author -- uff da"
expected = "--file moo.txt --persist SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected


def test_without_sql_comment_unspaced_comment():

line = "SELECT * FROM author --uff da"
expected = "SELECT * FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected


def test_without_sql_comment_dashes_in_string():

line = "SELECT '--very --confusing' FROM author -- uff da"
expected = "SELECT '--very --confusing' FROM author"
assert without_sql_comment(parser=parser_stub, line=line) == expected


def test_without_sql_comment_with_arg_and_leading_comment():

line = "--file moo.txt --persist --comment, not arg"
expected = "--file moo.txt --persist"
assert without_sql_comment(parser=parser_stub, line=line) == expected


def test_without_sql_persist():

line = "--persist my_table --uff da"
expected = "--persist my_table"
assert without_sql_comment(parser=parser_stub, line=line) == expected
Expand Down Expand Up @@ -224,7 +215,6 @@ def complete_with_defaults(mapping):
],
)
def test_magic_args(ip, line, expected):

sql_line = ip.magics_manager.lsmagic()["line"]["sql"]

args = magic_args(sql_line, line)
Expand Down

0 comments on commit bed8b36

Please sign in to comment.