From aa360f9505c7b0e3ad8968cf38b2ccfc9df9d4a1 Mon Sep 17 00:00:00 2001 From: Anirudh Iyer <69951190+AnirudhVIyer@users.noreply.github.com> Date: Fri, 23 Jun 2023 19:40:03 -0400 Subject: [PATCH] modified TableDescription formatting (#616) --- CHANGELOG.md | 2 +- src/sql/inspect.py | 242 +++++++++++++++--- .../integration/test_generic_db_operations.py | 67 ++++- src/tests/test_magic_cmd.py | 168 ++++++++++-- 4 files changed, 409 insertions(+), 70 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57cf4740b..70bae7a1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,8 @@ # CHANGELOG ## 0.7.10dev +* [Feature] Modified `TableDescription` to add styling, generate messages and format the calculated outputs (#459) * [Feature] Support flexible spacing `myvar=<<` operator ([#525](https://github.com/ploomber/jupysql/issues/525)) - * [Doc] Modified integrations content to ensure they're all consistent (#523) * [Doc] Document --persist-replace in API section (#539) * [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631) diff --git a/src/sql/inspect.py b/src/sql/inspect.py index efc6e0b75..f4455370c 100644 --- a/src/sql/inspect.py +++ b/src/sql/inspect.py @@ -8,6 +8,7 @@ import math from sql import util from IPython.core.display import HTML +import uuid def _get_inspector(conn): @@ -77,6 +78,94 @@ def _get_row_with_most_keys(rows): return list(rows[max_idx]) +def _is_numeric(value): + """Check if a column has numeric and not categorical datatype""" + try: + if isinstance(value, bool): + return False + float(value) # Try to convert the value to float + return True + except (TypeError, ValueError): + return False + + +def _is_numeric_as_str(column, value): + """Check if a column contains numerical data stored as `str`""" + try: + if isinstance(value, str) and _is_numeric(value): + return True + return False + except ValueError: + pass + + +def _generate_column_styles( + column_indices, unique_id, background_color="#FFFFCC", text_color="black" +): + """ + Generate CSS styles to change the background-color of all columns + with data-type mismatch. + + Parameters + ---------- + column_indices (list): List of column indices with data-type mismatch. + unique_id (str): Unique ID for the current table. + background_color (str, optional): Background color for the mismatched columns. + text_color (str, optional): Text color for the mismatched columns. + + Returns: + str: HTML style tags containing the CSS styles for the mismatched columns. + """ + + styles = "" + for index in column_indices: + styles = f"""{styles} + #profile-table-{unique_id} td:nth-child({index + 1}) {{ + background-color: {background_color}; + color: {text_color}; + }} + """ + return f"" + + +def _generate_message(column_indices, columns): + """Generate a message indicating all columns with a datatype mismatch""" + message = "Columns " + for c in column_indices: + col = columns[c - 1] + message = f"{message}{col}" + message = ( + f"{message} have a datatype mismatch -> numeric values stored as a string." + ) + message = f"{message}
Cannot calculate mean/min/max/std/percentiles" + return message + + +def _assign_column_specific_stats(col_stats, is_numeric): + """ + Assign NaN values to categorical/numerical specific statistic. + + Parameters + ---------- + col_stats (dict): Dictionary containing column statistics. + is_numeric (bool): Flag indicating whether the column is numeric or not. + + Returns: + dict: Updated col_stats dictionary. + """ + categorical_stats = ["top", "freq"] + numerical_stats = ["mean", "min", "max", "std", "25%", "50%", "75%"] + + if is_numeric: + for stat in categorical_stats: + col_stats[stat] = math.nan + else: + for stat in numerical_stats: + col_stats[stat] = math.nan + + return col_stats + + @modify_exceptions class Columns(DatabaseInspection): """ @@ -108,27 +197,36 @@ def __init__(self, name, schema, conn=None) -> None: @modify_exceptions class TableDescription(DatabaseInspection): """ - Generates descriptive statistics. + Generates descriptive statistics. + + -------------------------------------- + Descriptive statistics are: + + Count - Number of all not None values - Descriptive statistics are: + Mean - Mean of the values - Count - Number of all not None values + Max - Maximum of the values in the object. - Mean - Mean of the values + Min - Minimum of the values in the object. - Max - Maximum of the values in the object. + STD - Standard deviation of the observations - Min - Minimum of the values in the object. + 25h, 50h and 75h percentiles - STD - Standard deviation of the observations + Unique - Number of not None unique values - 25h, 50h and 75h percentiles + Top - The most frequent value - Unique - Number of not None unique values + Freq - Frequency of the top value - Top - The most frequent value + ------------------------------------------ + Following statistics will be calculated for :- - Freq - Frequency of the top value + Categorical columns - [Count, Unique, Top, Freq] + + Numerical columns - [Count, Unique, Mean, Max, Min, + STD, 25h, 50h and 75h percentiles] """ @@ -141,7 +239,6 @@ def __init__(self, table_name, schema=None) -> None: columns_query_result = sql.run.raw_run( Connection.current, f"SELECT * FROM {table_name} WHERE 1=0" ) - if Connection.is_custom_connection(): columns = [i[0] for i in columns_query_result.description] else: @@ -149,10 +246,28 @@ def __init__(self, table_name, schema=None) -> None: table_stats = dict({}) columns_to_include_in_report = set() + columns_with_styles = [] + message_check = False - for column in columns: + for i, column in enumerate(columns): table_stats[column] = dict() + # check the datatype of a column + try: + result = sql.run.raw_run( + Connection.current, f"""SELECT {column} FROM {table_name} LIMIT 1""" + ).fetchone() + + value = result[0] + is_numeric = isinstance(value, (int, float)) or ( + isinstance(value, str) and _is_numeric(value) + ) + except ValueError: + is_numeric = True + + if _is_numeric_as_str(column, value): + columns_with_styles.append(i + 1) + message_check = True # Note: index is reserved word in sqlite try: result_col_freq_values = sql.run.raw_run( @@ -183,10 +298,12 @@ def __init__(self, table_name, schema=None) -> None: """, ).fetchall() - table_stats[column]["min"] = result_value_values[0][0] - table_stats[column]["max"] = result_value_values[0][1] + columns_to_include_in_report.update(["count", "min", "max"]) table_stats[column]["count"] = result_value_values[0][2] + table_stats[column]["min"] = round(result_value_values[0][0], 4) + table_stats[column]["max"] = round(result_value_values[0][1], 4) + columns_to_include_in_report.update(["count", "min", "max"]) except Exception: @@ -204,9 +321,7 @@ def __init__(self, table_name, schema=None) -> None: """, ).fetchall() table_stats[column]["unique"] = result_value_values[0][0] - columns_to_include_in_report.update(["unique"]) - except Exception: pass @@ -220,8 +335,8 @@ def __init__(self, table_name, schema=None) -> None: """, ).fetchall() - table_stats[column]["mean"] = float(results_avg[0][0]) columns_to_include_in_report.update(["mean"]) + table_stats[column]["mean"] = format(float(results_avg[0][0]), ".4f") except Exception: table_stats[column]["mean"] = math.nan @@ -246,11 +361,10 @@ def __init__(self, table_name, schema=None) -> None: """, ).fetchall() + columns_to_include_in_report.update(special_numeric_keys) for i, key in enumerate(special_numeric_keys): # r_key = f'key_{key.replace("%", "")}' - table_stats[column][key] = float(result[0][i]) - - columns_to_include_in_report.update(special_numeric_keys) + table_stats[column][key] = format(float(result[0][i]), ".4f") except TypeError: # for non numeric values @@ -268,22 +382,73 @@ def __init__(self, table_name, schema=None) -> None: # We ignore the cell stats for such case. pass + table_stats[column] = _assign_column_specific_stats( + table_stats[column], is_numeric + ) + self._table = PrettyTable() self._table.field_names = [" "] + list(table_stats.keys()) - rows = list(columns_to_include_in_report) - rows.sort(reverse=True) - for row in rows: - values = [row] - for column in table_stats: - if row in table_stats[column]: - value = table_stats[column][row] - else: - value = "" - value = util.convert_to_scientific(value) - values.append(value) + custom_order = [ + "count", + "unique", + "top", + "freq", + "mean", + "std", + "min", + "25%", + "50%", + "75%", + "max", + ] + + for row in custom_order: + if row.lower() in [r.lower() for r in columns_to_include_in_report]: + values = [row] + for column in table_stats: + if row in table_stats[column]: + value = table_stats[column][row] + else: + value = "" + # value = util.convert_to_scientific(value) + values.append(value) + + self._table.add_row(values) + + unique_id = str(uuid.uuid4()).replace("-", "") + column_styles = _generate_column_styles(columns_with_styles, unique_id) + + if message_check: + message_content = _generate_message(columns_with_styles, list(columns)) + warning_background = "#FFFFCC" + warning_title = "Warning: " + else: + message_content = "" + warning_background = "white" + warning_title = "" + + database = Connection.current.url + db_driver = Connection.current._get_curr_sqlalchemy_connection_info()["driver"] + if "duckdb" in database: + db_message = "" + else: + db_message = f"""Following statistics are not available in + {db_driver}: STD, 25%, 50%, 75%""" + + db_html = ( + f"
" + f" {db_message}" + "
" + ) - self._table.add_row(values) + message_html = ( + f"
" + f"{warning_title} {message_content}" + "
" + ) # Inject css to html to make first column sticky sticky_column_css = """""" self._table_html = HTML( - sticky_column_css - + self._table.get_html_string(attributes={"id": "profile-table"}) + db_html + + sticky_column_css + + column_styles + + self._table.get_html_string( + attributes={"id": f"profile-table-{unique_id}"} + ) + + message_html ).__html__() self._table_txt = self._table.get_string() diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 6c25a9110..4d23f6064 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -392,18 +392,54 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): """ ) + +@pytest.mark.parametrize( + "ip_with_dynamic_db", + [ + ("ip_with_postgreSQL"), + ("ip_with_mySQL"), + ("ip_with_mariaDB"), + ("ip_with_SQLite"), + ("ip_with_duckDB"), + ("ip_with_MSSQL"), + pytest.param( + "ip_with_Snowflake", + marks=pytest.mark.xfail( + reason="Something wrong with test_sql_cmd_magic_dos in snowflake" + ), + ), + ("ip_with_oracle"), + ], +) +def test_profile_data_mismatch(ip_with_dynamic_db, request, capsys): + ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db) + ip_with_dynamic_db.run_cell( - "%sqlcmd test --table test_numbers --column value --greater-or-equal 3" + """ + %%sql sqlite:// + CREATE TABLE people (name varchar(50),age varchar(50),number int, + country varchar(50),gender_1 varchar(50), gender_2 varchar(50)); + INSERT INTO people VALUES ('joe', '48', 82, 'usa', '0', 'male'); + INSERT INTO people VALUES ('paula', '50', 93, 'uk', '1', 'female'); + """ ) - _out = capsys.readouterr() + out = ip_with_dynamic_db.run_cell("%sqlcmd profile -t people").result - assert "greater_or_equal" in _out.out - assert "0" in _out.out + stats_table_html = out._table_html + + assert "td:nth-child(3)" in stats_table_html + assert "td:nth-child(6)" in stats_table_html + assert "td:nth-child(7)" not in stats_table_html + assert "td:nth-child(4)" not in stats_table_html + assert ( + "Columns agegender_1 have a datatype mismatch" + in stats_table_html + ) @pytest.mark.parametrize( - "ip_with_dynamic_db, table, table_columns, expected", + "ip_with_dynamic_db, table, table_columns, expected, message", [ ( "ip_with_postgreSQL", @@ -422,6 +458,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): "50%": [22.0, math.nan], "75%": [33.0, math.nan], }, + None, ), pytest.param( "ip_with_mySQL", @@ -436,6 +473,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): "freq": [15], "top": ["Kevin Kelly"], }, + "Following statistics are not available in", marks=pytest.mark.xfail( reason="Need to get column names from table with a different query" ), @@ -453,6 +491,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): "freq": [15], "top": ["Kevin Kelly"], }, + "Following statistics are not available in", marks=pytest.mark.xfail( reason="Need to get column names from table with a different query" ), @@ -470,6 +509,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): "freq": [15], "top": ["Kevin Kelly"], }, + "Following statistics are not available in", ), ( "ip_with_duckDB", @@ -488,18 +528,21 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): "50%": [22.0, math.nan], "75%": [33.0, math.nan], }, + None, ), ( "ip_with_MSSQL", "taxi", ["taxi_driver_name"], {"unique": [3], "min": ["Eric Ken"], "max": ["Kevin Kelly"], "count": [45]}, + "Following statistics are not available in", ), pytest.param( "ip_with_Snowflake", "taxi", ["taxi_driver_name"], {}, + None, marks=pytest.mark.xfail( reason="Something wrong with test_profile_query in snowflake" ), @@ -509,6 +552,7 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): "taxi", ["taxi_driver_name"], {}, + None, marks=pytest.mark.xfail( reason="Something wrong with test_profile_query in snowflake" ), @@ -516,7 +560,13 @@ def test_sql_cmd_magic_dos(ip_with_dynamic_db, request, capsys): ], ) def test_profile_query( - request, ip_with_dynamic_db, table, table_columns, expected, test_table_name_dict + request, + ip_with_dynamic_db, + table, + table_columns, + expected, + test_table_name_dict, + message, ): pytest.skip("Skip on unclosed session issue") ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db) @@ -528,7 +578,7 @@ def test_profile_query( ).result stats_table = out._table - + stats_table_html = out._table_html assert len(stats_table.rows) == len(expected) for row in stats_table: @@ -542,6 +592,9 @@ def test_profile_query( assert criteria in expected assert cell_value == str(expected[criteria][i]) + if message: + assert message in stats_table_html + @pytest.mark.parametrize( "table", diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index 1a56e9b4c..d9b7ef97b 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -1,5 +1,5 @@ import sys - +import math import pytest from IPython.core.error import UsageError from pathlib import Path @@ -7,6 +7,7 @@ from sqlalchemy import create_engine from sql.connection import Connection from sql.store import store +from sql.inspect import _is_numeric VALID_COMMANDS_MESSAGE = ( @@ -14,6 +15,21 @@ ) +def _get_row_string(row, column_name): + """ + Helper function to retrieve the string value of a specific column in a table row. + + Parameters + ---------- + row: PrettyTable row object. + column_name: Name of the column. + + Returns: + String value of the specified column in the row. + """ + return row.get_string(fields=[column_name], border=False, header=False).strip() + + @pytest.fixture def ip_snippets(ip): for key in list(store): @@ -168,12 +184,12 @@ def test_table_profile(ip, tmp_empty): expected = { "count": [8, 8, 8, 8], - "mean": [12.2165, "6.875e-01", 88.75, 0.0], - "min": [10.532, 0.1, 82, ""], - "max": [14.44, 2.48, 98, "c"], + "mean": ["12.2165", "0.6875", "88.7500", math.nan], + "min": [10.532, 0.1, 82, math.nan], + "max": [14.44, 2.48, 98, math.nan], "unique": [8, 7, 8, 5], - "freq": [1, 2, 1, 4], - "top": [14.44, 0.2, 98, "a"], + "freq": [math.nan, math.nan, math.nan, 4], + "top": [math.nan, math.nan, math.nan, "a"], } out = ip.run_cell("%sqlcmd profile -t numbers").result @@ -183,21 +199,70 @@ def test_table_profile(ip, tmp_empty): assert len(stats_table.rows) == len(expected) for row in stats_table: - criteria = row.get_string(fields=[" "], border=False).strip() + profile_metric = _get_row_string(row, " ") + rating = _get_row_string(row, "rating") + price = _get_row_string(row, "price") + number = _get_row_string(row, "number") + word = _get_row_string(row, "word") + + assert profile_metric in expected + assert rating == str(expected[profile_metric][0]) + assert price == str(expected[profile_metric][1]) + assert number == str(expected[profile_metric][2]) + assert word == str(expected[profile_metric][3]) - rating = row.get_string(fields=["rating"], border=False, header=False).strip() + # Test sticky column style was injected + assert "position: sticky;" in out._table_html - price = row.get_string(fields=["price"], border=False, header=False).strip() - number = row.get_string(fields=["number"], border=False, header=False).strip() +def test_table_profile_with_stdev(ip, tmp_empty): + ip.run_cell( + """ + %%sql duckdb:// + CREATE TABLE numbers (rating float, price float, number int, word varchar(50)); + INSERT INTO numbers VALUES (14.44, 2.48, 82, 'a'); + INSERT INTO numbers VALUES (13.13, 1.50, 93, 'b'); + INSERT INTO numbers VALUES (12.59, 0.20, 98, 'a'); + INSERT INTO numbers VALUES (11.54, 0.41, 89, 'a'); + INSERT INTO numbers VALUES (10.532, 0.1, 88, 'c'); + INSERT INTO numbers VALUES (11.5, 0.2, 84, ' '); + INSERT INTO numbers VALUES (11.1, 0.3, 90, 'a'); + INSERT INTO numbers VALUES (12.9, 0.31, 86, ''); + """ + ) - word = row.get_string(fields=["word"], border=False, header=False).strip() + expected = { + "count": [8, 8, 8, 8], + "mean": ["12.2165", "0.6875", "88.7500", math.nan], + "min": [10.532, 0.1, 82, math.nan], + "max": [14.44, 2.48, 98, math.nan], + "unique": [8, 7, 8, 5], + "freq": [math.nan, math.nan, math.nan, 4], + "top": [math.nan, math.nan, math.nan, "a"], + "std": ["1.1958", "0.7956", "4.7631", math.nan], + "25%": ["11.1000", "0.2000", "84.0000", math.nan], + "50%": ["11.5400", "0.3000", "88.0000", math.nan], + "75%": ["12.9000", "0.4100", "90.0000", math.nan], + } - assert criteria in expected - assert rating == str(expected[criteria][0]) - assert price == str(expected[criteria][1]) - assert number == str(expected[criteria][2]) - assert word == str(expected[criteria][3]) + out = ip.run_cell("%sqlcmd profile -t numbers").result + + stats_table = out._table + + assert len(stats_table.rows) == len(expected) + + for row in stats_table: + profile_metric = _get_row_string(row, " ") + rating = _get_row_string(row, "rating") + price = _get_row_string(row, "price") + number = _get_row_string(row, "number") + word = _get_row_string(row, "word") + + assert profile_metric in expected + assert rating == str(expected[profile_metric][0]) + assert price == str(expected[profile_metric][1]) + assert number == str(expected[profile_metric][2]) + assert word == str(expected[profile_metric][3]) # Test sticky column style was injected assert "position: sticky;" in out._table_html @@ -227,14 +292,14 @@ def test_table_schema_profile(ip, tmp_empty): ) expected = { - "count": [3], - "mean": [22.0], - "min": [11.0], - "max": [33.0], - "std": [11.0], - "unique": [3], - "freq": [1], - "top": [33.0], + "count": ["3"], + "mean": ["22.0000"], + "min": ["11.0"], + "max": ["33.0"], + "std": ["11.0000"], + "unique": ["3"], + "freq": [math.nan], + "top": [math.nan], } out = ip.run_cell("%sqlcmd profile -t t --schema b_schema").result @@ -242,12 +307,61 @@ def test_table_schema_profile(ip, tmp_empty): stats_table = out._table for row in stats_table: - criteria = row.get_string(fields=[" "], border=False).strip() + profile_metric = _get_row_string(row, " ") cell = row.get_string(fields=["n"], border=False, header=False).strip() - if criteria in expected: - assert cell == str(expected[criteria][0]) + if profile_metric in expected: + assert cell == str(expected[profile_metric][0]) + + +def test_table_profile_warnings_styles(ip, tmp_empty): + ip.run_cell( + """ + %%sql sqlite:// + CREATE TABLE numbers (rating float,price varchar(50),number int,word varchar(50)); + INSERT INTO numbers VALUES (14.44, '2.48', 82, 'a'); + INSERT INTO numbers VALUES (13.13, '1.50', 93, 'b'); + """ + ) + out = ip.run_cell("%sqlcmd profile -t numbers").result + stats_table_html = out._table_html + assert "Columns price have a datatype mismatch" in stats_table_html + assert "td:nth-child(3)" in stats_table_html + assert "Following statistics are not available in" in stats_table_html + + +def test_profile_is_numeric(): + assert _is_numeric("123") is True + assert _is_numeric(None) is False + assert _is_numeric("abc") is False + assert _is_numeric("45.6") is True + assert _is_numeric(100) is True + assert _is_numeric(True) is False + assert _is_numeric("NaN") is True + assert _is_numeric(math.nan) is True + + +def test_table_profile_is_numeric(ip, tmp_empty): + ip.run_cell( + """ + %%sql sqlite:// + CREATE TABLE people (name varchar(50),age varchar(50),number int, + country varchar(50),gender_1 varchar(50), gender_2 varchar(50)); + INSERT INTO people VALUES ('joe', '48', 82, 'usa', '0', 'male'); + INSERT INTO people VALUES ('paula', '50', 93, 'uk', '1', 'female'); + """ + ) + out = ip.run_cell("%sqlcmd profile -t people").result + stats_table_html = out._table_html + assert "td:nth-child(3)" in stats_table_html + assert "td:nth-child(6)" in stats_table_html + assert "td:nth-child(7)" not in stats_table_html + assert "td:nth-child(4)" not in stats_table_html + assert ( + "Columns agegender_1 have a datatype mismatch" + in stats_table_html + ) def test_table_profile_store(ip, tmp_empty):