From aa360f9505c7b0e3ad8968cf38b2ccfc9df9d4a1 Mon Sep 17 00:00:00 2001
From: Anirudh Iyer <69951190+AnirudhVIyer@users.noreply.github.com>
Date: Fri, 23 Jun 2023 19:40:03 -0400
Subject: [PATCH] modified TableDescription formatting (#616)
---
CHANGELOG.md | 2 +-
src/sql/inspect.py | 242 +++++++++++++++---
.../integration/test_generic_db_operations.py | 67 ++++-
src/tests/test_magic_cmd.py | 168 ++++++++++--
4 files changed, 409 insertions(+), 70 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 57cf4740b..70bae7a1f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,8 @@
# CHANGELOG
## 0.7.10dev
+* [Feature] Modified `TableDescription` to add styling, generate messages and format the calculated outputs (#459)
* [Feature] Support flexible spacing `myvar=<<` operator ([#525](https://github.com/ploomber/jupysql/issues/525))
-
* [Doc] Modified integrations content to ensure they're all consistent (#523)
* [Doc] Document --persist-replace in API section (#539)
* [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631)
diff --git a/src/sql/inspect.py b/src/sql/inspect.py
index efc6e0b75..f4455370c 100644
--- a/src/sql/inspect.py
+++ b/src/sql/inspect.py
@@ -8,6 +8,7 @@
import math
from sql import util
from IPython.core.display import HTML
+import uuid
def _get_inspector(conn):
@@ -77,6 +78,94 @@ def _get_row_with_most_keys(rows):
return list(rows[max_idx])
+def _is_numeric(value):
+    """Check whether a value is numeric rather than categorical"""
+ try:
+ if isinstance(value, bool):
+ return False
+ float(value) # Try to convert the value to float
+ return True
+ except (TypeError, ValueError):
+ return False
+
+
+def _is_numeric_as_str(column, value):
+ """Check if a column contains numerical data stored as `str`"""
+ try:
+ if isinstance(value, str) and _is_numeric(value):
+ return True
+ return False
+ except ValueError:
+ pass
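+    # For example, _is_numeric_as_str("age", "48") is True (numeric data stored
+    # as a string), while _is_numeric_as_str("age", 48) and
+    # _is_numeric_as_str("name", "joe") are both False.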
+
+
+def _generate_column_styles(
+ column_indices, unique_id, background_color="#FFFFCC", text_color="black"
+):
+ """
+ Generate CSS styles to change the background-color of all columns
+ with data-type mismatch.
+
+ Parameters
+ ----------
+ column_indices (list): List of column indices with data-type mismatch.
+ unique_id (str): Unique ID for the current table.
+ background_color (str, optional): Background color for the mismatched columns.
+ text_color (str, optional): Text color for the mismatched columns.
+
+ Returns:
+ str: HTML style tags containing the CSS styles for the mismatched columns.
+ """
+
+ styles = ""
+ for index in column_indices:
+ styles = f"""{styles}
+ #profile-table-{unique_id} td:nth-child({index + 1}) {{
+ background-color: {background_color};
+ color: {text_color};
+ }}
+ """
+    return f"<style>{styles}</style>"
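+    # For example, _generate_column_styles([2], "abc") produces a style block
+    # containing "#profile-table-abc td:nth-child(3) { background-color: #FFFFCC;
+    # color: black; }", i.e. the rule that highlights the third rendered column;
+    # the tests below look for the "td:nth-child(3)" substring.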
+
+
+def _generate_message(column_indices, columns):
+ """Generate a message indicating all columns with a datatype mismatch"""
+ message = "Columns "
+ for c in column_indices:
+ col = columns[c - 1]
+        message = f"{message}<code>{col}</code>"
+ message = (
+ f"{message} have a datatype mismatch -> numeric values stored as a string."
+ )
+    message = f"{message} <br> Cannot calculate mean/min/max/std/percentiles"
+ return message
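+    # For example, with columns ("name", "age", "number", "country", "gender_1",
+    # "gender_2") and column_indices [2, 5], the returned message starts with
+    # "Columns <code>age</code><code>gender_1</code> have a datatype mismatch",
+    # which is the substring the profile tests assert against out._table_html.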
+
+
+def _assign_column_specific_stats(col_stats, is_numeric):
+ """
+    Assign NaN values to the statistics that do not apply to the column type.
+
+ Parameters
+ ----------
+ col_stats (dict): Dictionary containing column statistics.
+ is_numeric (bool): Flag indicating whether the column is numeric or not.
+
+ Returns:
+ dict: Updated col_stats dictionary.
+ """
+ categorical_stats = ["top", "freq"]
+ numerical_stats = ["mean", "min", "max", "std", "25%", "50%", "75%"]
+
+ if is_numeric:
+ for stat in categorical_stats:
+ col_stats[stat] = math.nan
+ else:
+ for stat in numerical_stats:
+ col_stats[stat] = math.nan
+
+ return col_stats
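+    # For example, a numeric column keeps mean/min/max/std/percentiles but gets
+    # top and freq set to NaN, while a categorical column keeps top/freq and gets
+    # NaN for every numerical statistic -- which is why the expected values in
+    # the tests below mix real numbers with math.nan.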
+
+
@modify_exceptions
class Columns(DatabaseInspection):
"""
@@ -108,27 +197,36 @@ def __init__(self, name, schema, conn=None) -> None:
@modify_exceptions
class TableDescription(DatabaseInspection):
"""
-    Generates descriptive statistics.
-    Descriptive statistics are:
-    Count - Number of all not None values
-    Mean - Mean of the values
-    Max - Maximum of the values in the object.
-    Min - Minimum of the values in the object.
-    STD - Standard deviation of the observations
-    25h, 50h and 75h percentiles
-    Unique - Number of not None unique values
-    Top - The most frequent value
-    Freq - Frequency of the top value
+    Generates descriptive statistics.
+
+    --------------------------------------
+    Descriptive statistics are:
+
+    Count - Number of all not None values
+    Mean - Mean of the values
+    Max - Maximum of the values in the object.
+    Min - Minimum of the values in the object.
+    STD - Standard deviation of the observations
+    25h, 50h and 75h percentiles
+    Unique - Number of not None unique values
+    Top - The most frequent value
+    Freq - Frequency of the top value
+    ------------------------------------------
+    Following statistics will be calculated for :-
+    Categorical columns - [Count, Unique, Top, Freq]
+
+    Numerical columns - [Count, Unique, Mean, Max, Min,
+    STD, 25h, 50h and 75h percentiles]
"""
@@ -141,7 +239,6 @@ def __init__(self, table_name, schema=None) -> None:
columns_query_result = sql.run.raw_run(
Connection.current, f"SELECT * FROM {table_name} WHERE 1=0"
)
-
if Connection.is_custom_connection():
columns = [i[0] for i in columns_query_result.description]
else:
@@ -149,10 +246,28 @@ def __init__(self, table_name, schema=None) -> None:
table_stats = dict({})
columns_to_include_in_report = set()
+ columns_with_styles = []
+ message_check = False
- for column in columns:
+ for i, column in enumerate(columns):
table_stats[column] = dict()
+ # check the datatype of a column
+ try:
+ result = sql.run.raw_run(
+ Connection.current, f"""SELECT {column} FROM {table_name} LIMIT 1"""
+ ).fetchone()
+
+ value = result[0]
+ is_numeric = isinstance(value, (int, float)) or (
+ isinstance(value, str) and _is_numeric(value)
+ )
+ except ValueError:
+ is_numeric = True
+
+ if _is_numeric_as_str(column, value):
+ columns_with_styles.append(i + 1)
+ message_check = True
# Note: index is reserved word in sqlite
try:
result_col_freq_values = sql.run.raw_run(
@@ -183,10 +298,12 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()
- table_stats[column]["min"] = result_value_values[0][0]
- table_stats[column]["max"] = result_value_values[0][1]
+ columns_to_include_in_report.update(["count", "min", "max"])
table_stats[column]["count"] = result_value_values[0][2]
+ table_stats[column]["min"] = round(result_value_values[0][0], 4)
+ table_stats[column]["max"] = round(result_value_values[0][1], 4)
+
columns_to_include_in_report.update(["count", "min", "max"])
except Exception:
@@ -204,9 +321,7 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()
table_stats[column]["unique"] = result_value_values[0][0]
-
columns_to_include_in_report.update(["unique"])
-
except Exception:
pass
@@ -220,8 +335,8 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()
- table_stats[column]["mean"] = float(results_avg[0][0])
columns_to_include_in_report.update(["mean"])
+ table_stats[column]["mean"] = format(float(results_avg[0][0]), ".4f")
except Exception:
table_stats[column]["mean"] = math.nan
@@ -246,11 +361,10 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()
+ columns_to_include_in_report.update(special_numeric_keys)
for i, key in enumerate(special_numeric_keys):
# r_key = f'key_{key.replace("%", "")}'
- table_stats[column][key] = float(result[0][i])
-
- columns_to_include_in_report.update(special_numeric_keys)
+ table_stats[column][key] = format(float(result[0][i]), ".4f")
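+                # Note: format(value, ".4f") returns a *string* such as "0.6875",
+                # so mean/std/percentiles end up as 4-decimal strings in the
+                # report, while round() keeps min/max numeric.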
except TypeError:
# for non numeric values
@@ -268,22 +382,73 @@ def __init__(self, table_name, schema=None) -> None:
# We ignore the cell stats for such case.
pass
+ table_stats[column] = _assign_column_specific_stats(
+ table_stats[column], is_numeric
+ )
+
self._table = PrettyTable()
self._table.field_names = [" "] + list(table_stats.keys())
- rows = list(columns_to_include_in_report)
- rows.sort(reverse=True)
- for row in rows:
- values = [row]
- for column in table_stats:
- if row in table_stats[column]:
- value = table_stats[column][row]
- else:
- value = ""
- value = util.convert_to_scientific(value)
- values.append(value)
+ custom_order = [
+ "count",
+ "unique",
+ "top",
+ "freq",
+ "mean",
+ "std",
+ "min",
+ "25%",
+ "50%",
+ "75%",
+ "max",
+ ]
+
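+        # Rows are rendered in this fixed, describe()-like order, and a statistic
+        # is only added if it was actually computed above (i.e. it appears in
+        # columns_to_include_in_report); unsupported stats are simply omitted.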
+ for row in custom_order:
+ if row.lower() in [r.lower() for r in columns_to_include_in_report]:
+ values = [row]
+ for column in table_stats:
+ if row in table_stats[column]:
+ value = table_stats[column][row]
+ else:
+ value = ""
+ # value = util.convert_to_scientific(value)
+ values.append(value)
+
+ self._table.add_row(values)
+
+ unique_id = str(uuid.uuid4()).replace("-", "")
+ column_styles = _generate_column_styles(columns_with_styles, unique_id)
+
+ if message_check:
+ message_content = _generate_message(columns_with_styles, list(columns))
+ warning_background = "#FFFFCC"
+ warning_title = "Warning: "
+ else:
+ message_content = ""
+ warning_background = "white"
+ warning_title = ""
+
+ database = Connection.current.url
+ db_driver = Connection.current._get_curr_sqlalchemy_connection_info()["driver"]
+ if "duckdb" in database:
+ db_message = ""
+ else:
+ db_message = f"""Following statistics are not available in
+ {db_driver}: STD, 25%, 50%, 75%"""
+
+        db_html = (
+            f"
diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py
--- a/src/tests/integration/test_generic_db_operations.py
+++ b/src/tests/integration/test_generic_db_operations.py
+    assert (
+        "Columns <code>age</code><code>gender_1</code> have a datatype mismatch"
+        in stats_table_html
+    )
+
+
@pytest.mark.parametrize(
- "ip_with_dynamic_db, table, table_columns, expected",
+ "ip_with_dynamic_db, table, table_columns, expected, message",
[
(
"ip_with_postgreSQL",
@@ -422,6 +458,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys):
"50%": [22.0, math.nan],
"75%": [33.0, math.nan],
},
+ None,
),
pytest.param(
"ip_with_mySQL",
@@ -436,6 +473,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys):
"freq": [15],
"top": ["Kevin Kelly"],
},
+ "Following statistics are not available in",
marks=pytest.mark.xfail(
reason="Need to get column names from table with a different query"
),
@@ -453,6 +491,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys):
"freq": [15],
"top": ["Kevin Kelly"],
},
+ "Following statistics are not available in",
marks=pytest.mark.xfail(
reason="Need to get column names from table with a different query"
),
@@ -470,6 +509,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys):
"freq": [15],
"top": ["Kevin Kelly"],
},
+ "Following statistics are not available in",
),
(
"ip_with_duckDB",
@@ -488,18 +528,21 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys):
"50%": [22.0, math.nan],
"75%": [33.0, math.nan],
},
+ None,
),
(
"ip_with_MSSQL",
"taxi",
["taxi_driver_name"],
{"unique": [3], "min": ["Eric Ken"], "max": ["Kevin Kelly"], "count": [45]},
+ "Following statistics are not available in",
),
pytest.param(
"ip_with_Snowflake",
"taxi",
["taxi_driver_name"],
{},
+ None,
marks=pytest.mark.xfail(
reason="Something wrong with test_profile_query in snowflake"
),
@@ -509,6 +552,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys):
"taxi",
["taxi_driver_name"],
{},
+ None,
marks=pytest.mark.xfail(
reason="Something wrong with test_profile_query in snowflake"
),
@@ -516,7 +560,13 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys):
],
)
def test_profile_query(
- request, ip_with_dynamic_db, table, table_columns, expected, test_table_name_dict
+ request,
+ ip_with_dynamic_db,
+ table,
+ table_columns,
+ expected,
+ test_table_name_dict,
+ message,
):
pytest.skip("Skip on unclosed session issue")
ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db)
@@ -528,7 +578,7 @@ def test_profile_query(
).result
stats_table = out._table
-
+ stats_table_html = out._table_html
assert len(stats_table.rows) == len(expected)
for row in stats_table:
@@ -542,6 +592,9 @@ def test_profile_query(
assert criteria in expected
assert cell_value == str(expected[criteria][i])
+ if message:
+ assert message in stats_table_html
+
@pytest.mark.parametrize(
"table",
diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py
index 1a56e9b4c..d9b7ef97b 100644
--- a/src/tests/test_magic_cmd.py
+++ b/src/tests/test_magic_cmd.py
@@ -1,5 +1,5 @@
import sys
-
+import math
import pytest
from IPython.core.error import UsageError
from pathlib import Path
@@ -7,6 +7,7 @@
from sqlalchemy import create_engine
from sql.connection import Connection
from sql.store import store
+from sql.inspect import _is_numeric
VALID_COMMANDS_MESSAGE = (
@@ -14,6 +15,21 @@
)
+def _get_row_string(row, column_name):
+ """
+ Helper function to retrieve the string value of a specific column in a table row.
+
+ Parameters
+ ----------
+ row: PrettyTable row object.
+ column_name: Name of the column.
+
+ Returns:
+ String value of the specified column in the row.
+ """
+ return row.get_string(fields=[column_name], border=False, header=False).strip()
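+    # For example, _get_row_string(row, "rating") returns the text of the
+    # "rating" cell in that row (e.g. "12.2165"), with borders, header and
+    # surrounding whitespace stripped.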
+
+
@pytest.fixture
def ip_snippets(ip):
for key in list(store):
@@ -168,12 +184,12 @@ def test_table_profile(ip, tmp_empty):
expected = {
"count": [8, 8, 8, 8],
- "mean": [12.2165, "6.875e-01", 88.75, 0.0],
- "min": [10.532, 0.1, 82, ""],
- "max": [14.44, 2.48, 98, "c"],
+ "mean": ["12.2165", "0.6875", "88.7500", math.nan],
+ "min": [10.532, 0.1, 82, math.nan],
+ "max": [14.44, 2.48, 98, math.nan],
"unique": [8, 7, 8, 5],
- "freq": [1, 2, 1, 4],
- "top": [14.44, 0.2, 98, "a"],
+ "freq": [math.nan, math.nan, math.nan, 4],
+ "top": [math.nan, math.nan, math.nan, "a"],
}
out = ip.run_cell("%sqlcmd profile -t numbers").result
@@ -183,21 +199,70 @@ def test_table_profile(ip, tmp_empty):
assert len(stats_table.rows) == len(expected)
for row in stats_table:
-        criteria = row.get_string(fields=[" "], border=False).strip()
-        rating = row.get_string(fields=["rating"], border=False, header=False).strip()
-        price = row.get_string(fields=["price"], border=False, header=False).strip()
-        number = row.get_string(fields=["number"], border=False, header=False).strip()
-        word = row.get_string(fields=["word"], border=False, header=False).strip()
-        assert criteria in expected
-        assert rating == str(expected[criteria][0])
-        assert price == str(expected[criteria][1])
-        assert number == str(expected[criteria][2])
-        assert word == str(expected[criteria][3])
+        profile_metric = _get_row_string(row, " ")
+        rating = _get_row_string(row, "rating")
+        price = _get_row_string(row, "price")
+        number = _get_row_string(row, "number")
+        word = _get_row_string(row, "word")
+
+        assert profile_metric in expected
+        assert rating == str(expected[profile_metric][0])
+        assert price == str(expected[profile_metric][1])
+        assert number == str(expected[profile_metric][2])
+        assert word == str(expected[profile_metric][3])
+
+    # Test sticky column style was injected
+    assert "position: sticky;" in out._table_html
+
+
+def test_table_profile_with_stdev(ip, tmp_empty):
+    ip.run_cell(
+        """
+    %%sql duckdb://
+    CREATE TABLE numbers (rating float, price float, number int, word varchar(50));
+    INSERT INTO numbers VALUES (14.44, 2.48, 82, 'a');
+    INSERT INTO numbers VALUES (13.13, 1.50, 93, 'b');
+    INSERT INTO numbers VALUES (12.59, 0.20, 98, 'a');
+    INSERT INTO numbers VALUES (11.54, 0.41, 89, 'a');
+    INSERT INTO numbers VALUES (10.532, 0.1, 88, 'c');
+    INSERT INTO numbers VALUES (11.5, 0.2, 84, ' ');
+    INSERT INTO numbers VALUES (11.1, 0.3, 90, 'a');
+    INSERT INTO numbers VALUES (12.9, 0.31, 86, '');
+    """
+    )
+
+    expected = {
+        "count": [8, 8, 8, 8],
+        "mean": ["12.2165", "0.6875", "88.7500", math.nan],
+        "min": [10.532, 0.1, 82, math.nan],
+        "max": [14.44, 2.48, 98, math.nan],
+        "unique": [8, 7, 8, 5],
+        "freq": [math.nan, math.nan, math.nan, 4],
+        "top": [math.nan, math.nan, math.nan, "a"],
+        "std": ["1.1958", "0.7956", "4.7631", math.nan],
+        "25%": ["11.1000", "0.2000", "84.0000", math.nan],
+        "50%": ["11.5400", "0.3000", "88.0000", math.nan],
+        "75%": ["12.9000", "0.4100", "90.0000", math.nan],
+    }
+
+    out = ip.run_cell("%sqlcmd profile -t numbers").result
+
+    stats_table = out._table
+
+    assert len(stats_table.rows) == len(expected)
+
+    for row in stats_table:
+        profile_metric = _get_row_string(row, " ")
+        rating = _get_row_string(row, "rating")
+        price = _get_row_string(row, "price")
+        number = _get_row_string(row, "number")
+        word = _get_row_string(row, "word")
+
+        assert profile_metric in expected
+        assert rating == str(expected[profile_metric][0])
+        assert price == str(expected[profile_metric][1])
+        assert number == str(expected[profile_metric][2])
+        assert word == str(expected[profile_metric][3])
# Test sticky column style was injected
assert "position: sticky;" in out._table_html
@@ -227,14 +292,14 @@ def test_table_schema_profile(ip, tmp_empty):
)
expected = {
- "count": [3],
- "mean": [22.0],
- "min": [11.0],
- "max": [33.0],
- "std": [11.0],
- "unique": [3],
- "freq": [1],
- "top": [33.0],
+ "count": ["3"],
+ "mean": ["22.0000"],
+ "min": ["11.0"],
+ "max": ["33.0"],
+ "std": ["11.0000"],
+ "unique": ["3"],
+ "freq": [math.nan],
+ "top": [math.nan],
}
out = ip.run_cell("%sqlcmd profile -t t --schema b_schema").result
@@ -242,12 +307,61 @@ def test_table_schema_profile(ip, tmp_empty):
stats_table = out._table
for row in stats_table:
- criteria = row.get_string(fields=[" "], border=False).strip()
+ profile_metric = _get_row_string(row, " ")
cell = row.get_string(fields=["n"], border=False, header=False).strip()
- if criteria in expected:
- assert cell == str(expected[criteria][0])
+ if profile_metric in expected:
+ assert cell == str(expected[profile_metric][0])
+
+
+def test_table_profile_warnings_styles(ip, tmp_empty):
+ ip.run_cell(
+ """
+ %%sql sqlite://
+ CREATE TABLE numbers (rating float,price varchar(50),number int,word varchar(50));
+ INSERT INTO numbers VALUES (14.44, '2.48', 82, 'a');
+ INSERT INTO numbers VALUES (13.13, '1.50', 93, 'b');
+ """
+ )
+ out = ip.run_cell("%sqlcmd profile -t numbers").result
+ stats_table_html = out._table_html
+    assert "Columns <code>price</code> have a datatype mismatch" in stats_table_html
+ assert "td:nth-child(3)" in stats_table_html
+ assert "Following statistics are not available in" in stats_table_html
+
+
+def test_profile_is_numeric():
+ assert _is_numeric("123") is True
+ assert _is_numeric(None) is False
+ assert _is_numeric("abc") is False
+ assert _is_numeric("45.6") is True
+ assert _is_numeric(100) is True
+ assert _is_numeric(True) is False
+ assert _is_numeric("NaN") is True
+ assert _is_numeric(math.nan) is True
+
+
+def test_table_profile_is_numeric(ip, tmp_empty):
+ ip.run_cell(
+ """
+ %%sql sqlite://
+ CREATE TABLE people (name varchar(50),age varchar(50),number int,
+ country varchar(50),gender_1 varchar(50), gender_2 varchar(50));
+ INSERT INTO people VALUES ('joe', '48', 82, 'usa', '0', 'male');
+ INSERT INTO people VALUES ('paula', '50', 93, 'uk', '1', 'female');
+ """
+ )
+ out = ip.run_cell("%sqlcmd profile -t people").result
+ stats_table_html = out._table_html
+ assert "td:nth-child(3)" in stats_table_html
+ assert "td:nth-child(6)" in stats_table_html
+ assert "td:nth-child(7)" not in stats_table_html
+ assert "td:nth-child(4)" not in stats_table_html
+ assert (
+        "Columns <code>age</code><code>gender_1</code> have a datatype mismatch"
+        in stats_table_html
+    )
def test_table_profile_store(ip, tmp_empty):