Skip to content

Commit

Permalink
Merge branch '518-refactor-magic_cmd' of https://github.com/AnirudhVI…
Browse files Browse the repository at this point in the history
…yer/jupysql into 518-refactor-magic_cmd

merge upstram
  • Loading branch information
AnirudhVIyer committed Jun 24, 2023
2 parents ec6b435 + 4e53224 commit 0efeb65
Show file tree
Hide file tree
Showing 4 changed files with 409 additions and 69 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## 0.7.10dev

* [Feature] Modified `TableDescription` to add styling, generate messages and format the calculated outputs (#459)
* [Feature] Support flexible spacing `myvar=<<` operator ([#525](https://github.com/ploomber/jupysql/issues/525))
* [Doc] Modified integrations content to ensure they're all consistent (#523)
* [Doc] Document --persist-replace in API section (#539)
Expand Down
242 changes: 207 additions & 35 deletions src/sql/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import math
from sql import util
from IPython.core.display import HTML
import uuid


def _get_inspector(conn):
Expand Down Expand Up @@ -77,6 +78,94 @@ def _get_row_with_most_keys(rows):
return list(rows[max_idx])


def _is_numeric(value):
"""Check if a column has numeric and not categorical datatype"""
try:
if isinstance(value, bool):
return False
float(value) # Try to convert the value to float
return True
except (TypeError, ValueError):
return False


def _is_numeric_as_str(column, value):
"""Check if a column contains numerical data stored as `str`"""
try:
if isinstance(value, str) and _is_numeric(value):
return True
return False
except ValueError:
pass


def _generate_column_styles(
column_indices, unique_id, background_color="#FFFFCC", text_color="black"
):
"""
Generate CSS styles to change the background-color of all columns
with data-type mismatch.
Parameters
----------
column_indices (list): List of column indices with data-type mismatch.
unique_id (str): Unique ID for the current table.
background_color (str, optional): Background color for the mismatched columns.
text_color (str, optional): Text color for the mismatched columns.
Returns:
str: HTML style tags containing the CSS styles for the mismatched columns.
"""

styles = ""
for index in column_indices:
styles = f"""{styles}
#profile-table-{unique_id} td:nth-child({index + 1}) {{
background-color: {background_color};
color: {text_color};
}}
"""
return f"<style>{styles}</style>"


def _generate_message(column_indices, columns):
"""Generate a message indicating all columns with a datatype mismatch"""
message = "Columns "
for c in column_indices:
col = columns[c - 1]
message = f"{message}<code>{col}</code>"
message = (
f"{message} have a datatype mismatch -> numeric values stored as a string."
)
message = f"{message} <br> Cannot calculate mean/min/max/std/percentiles"
return message


def _assign_column_specific_stats(col_stats, is_numeric):
"""
Assign NaN values to categorical/numerical specific statistic.
Parameters
----------
col_stats (dict): Dictionary containing column statistics.
is_numeric (bool): Flag indicating whether the column is numeric or not.
Returns:
dict: Updated col_stats dictionary.
"""
categorical_stats = ["top", "freq"]
numerical_stats = ["mean", "min", "max", "std", "25%", "50%", "75%"]

if is_numeric:
for stat in categorical_stats:
col_stats[stat] = math.nan
else:
for stat in numerical_stats:
col_stats[stat] = math.nan

return col_stats


@modify_exceptions
class Columns(DatabaseInspection):
"""
Expand Down Expand Up @@ -108,27 +197,36 @@ def __init__(self, name, schema, conn=None) -> None:
@modify_exceptions
class TableDescription(DatabaseInspection):
"""
Generates descriptive statistics.
Generates descriptive statistics.
--------------------------------------
Descriptive statistics are:
Count - Number of all not None values
Descriptive statistics are:
Mean - Mean of the values
Count - Number of all not None values
Max - Maximum of the values in the object.
Mean - Mean of the values
Min - Minimum of the values in the object.
Max - Maximum of the values in the object.
STD - Standard deviation of the observations
Min - Minimum of the values in the object.
25h, 50h and 75h percentiles
STD - Standard deviation of the observations
Unique - Number of not None unique values
25h, 50h and 75h percentiles
Top - The most frequent value
Unique - Number of not None unique values
Freq - Frequency of the top value
Top - The most frequent value
------------------------------------------
Following statistics will be calculated for :-
Freq - Frequency of the top value
Categorical columns - [Count, Unique, Top, Freq]
Numerical columns - [Count, Unique, Mean, Max, Min,
STD, 25h, 50h and 75h percentiles]
"""

Expand All @@ -141,18 +239,35 @@ def __init__(self, table_name, schema=None) -> None:
columns_query_result = sql.run.raw_run(
Connection.current, f"SELECT * FROM {table_name} WHERE 1=0"
)

if Connection.is_custom_connection():
columns = [i[0] for i in columns_query_result.description]
else:
columns = columns_query_result.keys()

table_stats = dict({})
columns_to_include_in_report = set()
columns_with_styles = []
message_check = False

for column in columns:
for i, column in enumerate(columns):
table_stats[column] = dict()

# check the datatype of a column
try:
result = sql.run.raw_run(
Connection.current, f"""SELECT {column} FROM {table_name} LIMIT 1"""
).fetchone()

value = result[0]
is_numeric = isinstance(value, (int, float)) or (
isinstance(value, str) and _is_numeric(value)
)
except ValueError:
is_numeric = True

if _is_numeric_as_str(column, value):
columns_with_styles.append(i + 1)
message_check = True
# Note: index is reserved word in sqlite
try:
result_col_freq_values = sql.run.raw_run(
Expand Down Expand Up @@ -183,10 +298,12 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()

table_stats[column]["min"] = result_value_values[0][0]
table_stats[column]["max"] = result_value_values[0][1]
columns_to_include_in_report.update(["count", "min", "max"])
table_stats[column]["count"] = result_value_values[0][2]

table_stats[column]["min"] = round(result_value_values[0][0], 4)
table_stats[column]["max"] = round(result_value_values[0][1], 4)

columns_to_include_in_report.update(["count", "min", "max"])

except Exception:
Expand All @@ -204,9 +321,7 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()
table_stats[column]["unique"] = result_value_values[0][0]

columns_to_include_in_report.update(["unique"])

except Exception:
pass

Expand All @@ -220,8 +335,8 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()

table_stats[column]["mean"] = float(results_avg[0][0])
columns_to_include_in_report.update(["mean"])
table_stats[column]["mean"] = format(float(results_avg[0][0]), ".4f")

except Exception:
table_stats[column]["mean"] = math.nan
Expand All @@ -246,11 +361,10 @@ def __init__(self, table_name, schema=None) -> None:
""",
).fetchall()

columns_to_include_in_report.update(special_numeric_keys)
for i, key in enumerate(special_numeric_keys):
# r_key = f'key_{key.replace("%", "")}'
table_stats[column][key] = float(result[0][i])

columns_to_include_in_report.update(special_numeric_keys)
table_stats[column][key] = format(float(result[0][i]), ".4f")

except TypeError:
# for non numeric values
Expand All @@ -268,39 +382,97 @@ def __init__(self, table_name, schema=None) -> None:
# We ignore the cell stats for such case.
pass

table_stats[column] = _assign_column_specific_stats(
table_stats[column], is_numeric
)

self._table = PrettyTable()
self._table.field_names = [" "] + list(table_stats.keys())

rows = list(columns_to_include_in_report)
rows.sort(reverse=True)
for row in rows:
values = [row]
for column in table_stats:
if row in table_stats[column]:
value = table_stats[column][row]
else:
value = ""
value = util.convert_to_scientific(value)
values.append(value)
custom_order = [
"count",
"unique",
"top",
"freq",
"mean",
"std",
"min",
"25%",
"50%",
"75%",
"max",
]

for row in custom_order:
if row.lower() in [r.lower() for r in columns_to_include_in_report]:
values = [row]
for column in table_stats:
if row in table_stats[column]:
value = table_stats[column][row]
else:
value = ""
# value = util.convert_to_scientific(value)
values.append(value)

self._table.add_row(values)

unique_id = str(uuid.uuid4()).replace("-", "")
column_styles = _generate_column_styles(columns_with_styles, unique_id)

if message_check:
message_content = _generate_message(columns_with_styles, list(columns))
warning_background = "#FFFFCC"
warning_title = "Warning: "
else:
message_content = ""
warning_background = "white"
warning_title = ""

database = Connection.current.url
db_driver = Connection.current._get_curr_sqlalchemy_connection_info()["driver"]
if "duckdb" in database:
db_message = ""
else:
db_message = f"""Following statistics are not available in
{db_driver}: STD, 25%, 50%, 75%"""

db_html = (
f"<div style='position: sticky; left: 0; padding: 10px; "
f"font-size: 12px; color: #FFA500'>"
f"<strong></strong> {db_message}"
"</div>"
)

self._table.add_row(values)
message_html = (
f"<div style='position: sticky; left: 0; padding: 10px; "
f"font-size: 12px; color: black; background-color: {warning_background};'>"
f"<strong>{warning_title}</strong> {message_content}"
"</div>"
)

# Inject css to html to make first column sticky
sticky_column_css = """<style>
#profile-table td:first-child {
position: sticky;
left: 0;
background-color: var(--jp-cell-editor-background);
font-weight: bold;
}
#profile-table thead tr th:first-child {
position: sticky;
left: 0;
background-color: var(--jp-cell-editor-background);
font-weight: bold; /* Adding bold text */
}
</style>"""
self._table_html = HTML(
sticky_column_css
+ self._table.get_html_string(attributes={"id": "profile-table"})
db_html
+ sticky_column_css
+ column_styles
+ self._table.get_html_string(
attributes={"id": f"profile-table-{unique_id}"}
)
+ message_html
).__html__()

self._table_txt = self._table.get_string()
Expand Down
Loading

0 comments on commit 0efeb65

Please sign in to comment.