diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index 4ca3ace2..84f9071d 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -381,3 +381,60 @@ def test_remove_hashtags(self):
         s_true = pd.Series("Hi , we will remove you")
 
         self.assertEqual(preprocessing.remove_hashtags(s), s_true)
+
+    """
+    Test describe DataFrame
+    """
+
+    def test_describe(self):
+        df = pd.DataFrame(
+            [
+                ["here here here here go", "sport"],
+                ["There There There", "sport"],
+                ["Test, Test, Test, Test, Test, Test, Test, Test", "sport"],
+                [np.nan, "music"],
+                ["super super", pd.NA],
+                [pd.NA, pd.NA],
+                ["great great great great great", "music"],
+            ],
+            columns=["text", "topics"],
+        )
+        df_description = preprocessing.describe(df["text"], df["topics"])
+        df_true = pd.DataFrame(
+            [
+                7,
+                7,
+                2,
+                ["Test", "great", "here", "There", "super", "go"],
+                ["test", "great", "super", "go"],
+                6.0,
+                2.0,
+                15.0,
+                5.196152422706632,
+                3.0,
+                5.0,
+                5.0,
+                0.6,
+                0.4,
+            ],
+            columns=["Value"],
+            index=pd.MultiIndex.from_tuples(
+                [
+                    ("number of documents", ""),
+                    ("number of unique documents", ""),
+                    ("number of missing documents", ""),
+                    ("most common words", ""),
+                    ("most common words excluding stopwords", ""),
+                    ("average document length", ""),
+                    ("length of shortest document", ""),
+                    ("length of longest document", ""),
+                    ("standard deviation of document lengths", ""),
+                    ("25th percentile document lengths", ""),
+                    ("50th percentile document lengths", ""),
+                    ("75th percentile document lengths", ""),
+                    ("label distribution", "sport"),
+                    ("label distribution", "music"),
+                ]
+            ),
+        )
+        pd.testing.assert_frame_equal(df_description, df_true, rtol=1e-3)
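Note on where the expected numbers in df_true come from: after tokenization, the five documents with content have 5, 3, 15, 2, and 5 tokens (texthero's tokenizer splits "Test," into "Test" and ",", so the third document counts 8 words plus 7 commas). The summary statistics then follow directly from pandas; a minimal sketch to sanity-check them, independent of texthero:

    import pandas as pd

    # Token counts of the five non-missing documents in the test above.
    lengths = pd.Series([5, 3, 15, 2, 5])

    print(lengths.mean())   # 6.0
    print(lengths.std())    # 5.196152422706632 (sample std, ddof=1)
    print(lengths.quantile([0.25, 0.5, 0.75]).tolist())  # [3.0, 5.0, 5.0]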
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
index d0075389..963263e1 100644
--- a/tests/test_visualization.py
+++ b/tests/test_visualization.py
@@ -1,6 +1,7 @@
 import string
 
 import pandas as pd
+import plotly
 import doctest
 
 from texthero import visualization
@@ -79,3 +80,19 @@ def test_top_words_digits_punctuation(self):
     def test_wordcloud(self):
         s = pd.Series("one two three")
         self.assertEqual(visualization.wordcloud(s), None)
+
+    """
+    Test visualization of describe function
+    """
+
+    def test_visualization_describe(self):
+        df = pd.DataFrame(
+            [["one two three", "here"], ["one two three", "here"]],
+            columns=["text", "topic"],
+        )
+        self.assertIsInstance(
+            visualization.visualize_describe(
+                df["text"], df["topic"], return_figure=True
+            ),
+            plotly.graph_objs._figure.Figure,
+        )
diff --git a/texthero/preprocessing.py b/texthero/preprocessing.py
index 747c9598..6d9898af 100644
--- a/texthero/preprocessing.py
+++ b/texthero/preprocessing.py
@@ -14,6 +14,7 @@
 
 from texthero import stopwords as _stopwords
 from texthero._types import TokenSeries, TextSeries, InputSeries
+from texthero import visualization
 
 from typing import List, Callable, Union
 
@@ -906,3 +907,100 @@ def remove_hashtags(s: TextSeries) -> TextSeries:
     with a custom symbol.
     """
     return replace_hashtags(s, " ")
+
+
+@InputSeries(TextSeries)
+def describe(s: TextSeries, s_labels: pd.Series = None) -> pd.DataFrame:
+    """
+    Describe a given pandas TextSeries (consisting of strings
+    in every cell). Additionally gathers information
+    about class labels if they are given in s_labels.
+
+    Examples
+    --------
+    >>> import texthero as hero
+    >>> import pandas as pd
+    >>> df = pd.read_csv("https://raw.githubusercontent.com/jbesomi/texthero/master/dataset/bbcsport.csv") # doctest: +SKIP
+    >>> df.head(2) # doctest: +SKIP
+                                                    text     topic
+    0  Claxton hunting first major medal\n\nBritish h...  athletics
+    1  O'Sullivan could run in Worlds\n\nSonia O'Sull...  athletics
+    >>> # Describe both the text and the labels
+    >>> hero.describe(df["text"], df["topic"]) # doctest: +SKIP
+                                                                                      Value
+    number of documents                                                                 737
+    number of unique documents                                                          727
+    number of missing documents                                                           0
+    most common words                          [the, to, a, in, and, of, for, ", I, is]
+    most common words excluding stopwords  [said, first, england, game, one, year, two, w...
+    average document length                                                         387.803
+    length of shortest document                                                         119
+    length of longest document                                                         1855
+    standard deviation of document lengths                                          210.728
+    25th percentile document lengths                                                    241
+    50th percentile document lengths                                                    340
+    75th percentile document lengths                                                    494
+    label distribution                     football                                0.359566
+                                           rugby                                   0.199457
+                                           cricket                                  0.16825
+                                           athletics                               0.137042
+                                           tennis                                  0.135685
+    """
+    # Get values we need for several calculations.
+    description = {}
+    s_tokenized = tokenize(s)
+    has_content_mask = has_content(s)
+    document_lengths = s_tokenized[has_content_mask].map(len)
+    document_lengths_description = document_lengths.describe()
+
+    # Collect statistics.
+    description["number of documents"] = len(s.index)
+    description["number of unique documents"] = len(s.unique())
+    description["number of missing documents"] = (~has_content_mask).sum()
+    description["most common words"] = visualization.top_words(s).index[:10].tolist()
+    description["most common words excluding stopwords"] = (
+        s.pipe(clean).pipe(visualization.top_words).index[:10].tolist()
+    )
+
+    description["average document length"] = document_lengths_description["mean"]
+    description["length of shortest document"] = document_lengths_description["min"]
+    description["length of longest document"] = document_lengths_description["max"]
+    description[
+        "standard deviation of document lengths"
+    ] = document_lengths_description["std"]
+    description["25th percentile document lengths"] = document_lengths_description[
+        "25%"
+    ]
+    description["50th percentile document lengths"] = document_lengths_description[
+        "50%"
+    ]
+    description["75th percentile document lengths"] = document_lengths_description[
+        "75%"
+    ]
+
+    # Create output Series.
+    s_description = pd.Series(description)
+
+    # Potentially add information about label distribution.
+    if s_labels is not None:
+
+        s_labels_distribution = s_labels.value_counts(normalize=True)
+
+        # Give the label distribution a MultiIndex so it aligns with the other statistics.
+        s_labels_distribution.index = pd.MultiIndex.from_product(
+            [["label distribution"], s_labels_distribution.index.values]
+        )
+
+        s_description.index = pd.MultiIndex.from_product(
+            [s_description.index.values, [""]]
+        )
+
+        s_description = pd.concat([s_description, s_labels_distribution])
+
+    # DataFrame will look much nicer for users when printing.
+    df_description = pd.DataFrame(
+        s_description.values, index=s_description.index, columns=["Value"]
+    )
+    df_description.index.name = "Statistic"
+
+    return df_description
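Note on the index handling at the end of describe: plain statistics get an empty string as their second index level so they can share one frame with the per-label rows. A minimal standalone sketch of the same trick (the variable names here are illustrative, not part of the library):

    import pandas as pd

    stats = pd.Series({"number of documents": 3.0})
    labels = pd.Series(["a", "a", "b"]).value_counts(normalize=True)

    # Put both Series on a two-level index, then concatenate them.
    stats.index = pd.MultiIndex.from_product([stats.index, [""]])
    labels.index = pd.MultiIndex.from_product([["label distribution"], labels.index])

    print(pd.concat([stats, labels]))
    # number of documents       3.000000
    # label distribution   a    0.666667
    #                      b    0.333333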
""" +import plotly.graph_objects as go +from plotly.subplots import make_subplots +import os import pandas as pd import numpy as np import plotly.express as px +import warnings from wordcloud import WordCloud from texthero import preprocessing from texthero._types import TextSeries, InputSeries -import string from matplotlib.colors import LinearSegmentedColormap as lsg import matplotlib.pyplot as plt from collections import Counter +import string def scatterplot( @@ -304,3 +308,162 @@ def top_words(s: TextSeries, normalize=False) -> pd.Series: .explode() # one word for each line .value_counts(normalize=normalize) ) + + +def visualize_describe(s: TextSeries, s_labels: pd.Series = None, return_figure=False): + """ + Visualize statistics about a given TextSeries, and + optionally a given Series with labels/classes. + + This function visualizes the output of + :meth:`texthero.preprocessing.describe`. + + Parameters + ---------- + s: TextSeries + The Series that should be described. + + s_labels : pd.Series + A Series with the labels / classes / topics + of the texts in the first argument. + + return_figure : bool, default to False + Whether to return the figure instead of showing it. + + Examples + -------- + >>> import texthero as hero + >>> import pandas as pd + >>> df = pd.read_csv("https://raw.githubusercontent.com/jbesomi/texthero/master/dataset/bbcsport.csv") # doctest: +SKIP + >>> df.head(2) # doctest: +SKIP + text topic + 0 Claxton hunting first major medal\n\nBritish h... athletics + 1 O'Sullivan could run in Worlds\n\nSonia O'Sull... athletics + >>> # Describe both the text and the labels + >>> hero.visualize_describe(df["text"], df["topic"]) # doctest: +SKIP + """ + + # Gather data (most from hero.describe, just + # the document lengths histogram is calculated here). + s_tokenized = preprocessing.tokenize(s) + has_content_mask = preprocessing.has_content(s) + s_document_lengths = s_tokenized[has_content_mask].map(lambda x: len(x)) + + document_lengths_histogram = np.histogram(s_document_lengths.values, bins=20) + + document_lengths_histogram_df = pd.DataFrame( + { + "Document Length": np.insert(document_lengths_histogram[0], 0, 0), + "Number of Documents": document_lengths_histogram[1], + } + ) + + description = preprocessing.describe(s, s_labels) + + # Initialize Figure + fig = make_subplots( + rows=2, + cols=2, + specs=[ + [{"type": "sankey"}, {"type": "table"}], + [{"type": "scatter"}, {"type": "pie"}], + ], + column_widths=[0.7, 0.3], + ) + + # Create pie chart of label distribution if it was calculated. 
+ if "label distribution" in description.index: + label_distribution_pie_chart_df = description.loc["label distribution"] + label_distribution_pie_chart_fig = go.Pie( + labels=label_distribution_pie_chart_df.index.tolist(), + values=label_distribution_pie_chart_df.values.flatten().tolist(), + title="Label Distributions", + ) + else: + label_distribution_pie_chart_fig = None + + # Create histogram of document lengths + document_lengths_fig = go.Scatter( + x=document_lengths_histogram_df["Number of Documents"], + y=document_lengths_histogram_df["Document Length"], + fill="tozeroy", + name="Document Length Histogram", + showlegend=False, + ) + + if s_labels is not None: # labels given -> description output is multiindexed + n_total_docs = description.loc["number of documents"].values[0][0] + n_unique_docs = description.loc["number of unique documents"].values[0][0] + n_missing_docs = description.loc["number of missing documents"].values[0][0] + most_common_words = description.loc["most common words"].values[0][0] + most_common_words_excluding_stopwords = description.loc[ + "most common words excluding stopwords" + ].values[0][0] + else: + n_total_docs = description.loc["number of documents"].values[0] + n_unique_docs = description.loc["number of unique documents"].values[0] + n_missing_docs = description.loc["number of missing documents"].values[0] + most_common_words = description.loc["most common words"].values[0] + most_common_words_excluding_stopwords = description.loc[ + "most common words excluding stopwords" + ].values[0] + + # Create bar charts for documents / unique / missing + n_duplicate_docs = n_total_docs - n_unique_docs - n_missing_docs + + schart = go.Sankey( + node=dict( + pad=15, + thickness=20, + label=[ + "Total Number of Documents", + "Duplicate Documents", + "Unique Documents", + "Missing Documents", + ], + color=[ + "rgba(122,122,255,0.8)", + "rgba(255,153,51,0.8)", + "rgba(141,211,199,0.8)", + "rgba(235,83,83,0.8)", + ], + ), + link=dict( + # indices correspond to labels, eg A1, A2, A2, B1, ... + source=[0, 0, 0], + target=[2, 1, 3], + color=[ + "rgba(179,226,205,0.6)", + "rgba(250,201,152,0.6)", + "rgba(255,134,134,0.6)", + ], + value=[n_unique_docs, n_duplicate_docs, n_missing_docs,], + ), + ) + + # Create Table to show the 10 most common words (with and without stopwords) + table = go.Table( + header=dict(values=["Top Words with Stopwords", "Top Words without Stopwords"]), + cells=dict(values=[most_common_words, most_common_words_excluding_stopwords,]), + ) + + # Combine figures. + if label_distribution_pie_chart_fig is not None: + fig.add_trace(label_distribution_pie_chart_fig, row=2, col=2) + + fig.add_trace(document_lengths_fig, row=2, col=1) + + fig.add_trace(schart, row=1, col=1) + + fig.add_trace(table, row=1, col=2) + + # Style and show figure. + fig.update_layout(plot_bgcolor="rgb(255,255,255)", barmode="stack") + fig.update_xaxes(title_text="Document Length", row=2, col=1) + fig.update_yaxes(title_text="Number of Documents", row=2, col=1) + fig.update_layout(legend=dict(yanchor="bottom", y=0, x=1.1, xanchor="right",)) + + if return_figure: + return fig + else: + fig.show()