diff --git a/.changeset/lucky-towns-allow.md b/.changeset/lucky-towns-allow.md new file mode 100644 index 0000000000000..aa1e59ba64405 --- /dev/null +++ b/.changeset/lucky-towns-allow.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Clean up `gr.DataFrame.postprocess()` and fix issue with getting headers of empty dataframes diff --git a/demo/mini_leaderboard/run.ipynb b/demo/mini_leaderboard/run.ipynb index db8e589ee0769..1b5e62a09f4c3 100644 --- a/demo/mini_leaderboard/run.ipynb +++ b/demo/mini_leaderboard/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: mini_leaderboard"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio pandas "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('assets')\n", "!wget -q -O assets/__init__.py https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/__init__.py\n", "!wget -q -O assets/custom_css.css https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/custom_css.css\n", "!wget -q -O assets/leaderboard_data.json https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/leaderboard_data.json"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import pandas as pd\n", "from pathlib import Path\n", "\n", "abs_path = Path(__file__).parent.absolute()\n", "\n", "df = pd.read_json(str(abs_path / \"assets/leaderboard_data.json\"))\n", "invisible_df = df.copy()\n", "\n", "COLS = [\n", " \"T\",\n", " \"Model\",\n", " \"Average \u2b06\ufe0f\",\n", " \"ARC\",\n", " \"HellaSwag\",\n", " \"MMLU\",\n", " \"TruthfulQA\",\n", " \"Winogrande\",\n", " \"GSM8K\",\n", " \"Type\",\n", " \"Architecture\",\n", " \"Precision\",\n", " \"Merged\",\n", " \"Hub License\",\n", " \"#Params (B)\",\n", " \"Hub \u2764\ufe0f\",\n", " \"Model sha\",\n", " \"model_name_for_query\",\n", "]\n", "ON_LOAD_COLS = [\n", " \"T\",\n", " \"Model\",\n", " \"Average \u2b06\ufe0f\",\n", " \"ARC\",\n", " \"HellaSwag\",\n", " \"MMLU\",\n", " \"TruthfulQA\",\n", " \"Winogrande\",\n", " \"GSM8K\",\n", " \"model_name_for_query\",\n", "]\n", "TYPES = [\n", " \"str\",\n", " \"markdown\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"str\",\n", " \"str\",\n", " \"str\",\n", " \"str\",\n", " \"bool\",\n", " \"str\",\n", " \"number\",\n", " \"number\",\n", " \"bool\",\n", " \"str\",\n", " \"bool\",\n", " \"bool\",\n", " \"str\",\n", "]\n", "NUMERIC_INTERVALS = {\n", " \"?\": pd.Interval(-1, 0, closed=\"right\"),\n", " \"~1.5\": pd.Interval(0, 2, closed=\"right\"),\n", " \"~3\": pd.Interval(2, 4, closed=\"right\"),\n", " \"~7\": pd.Interval(4, 9, closed=\"right\"),\n", " \"~13\": pd.Interval(9, 20, closed=\"right\"),\n", " \"~35\": pd.Interval(20, 45, closed=\"right\"),\n", " \"~60\": pd.Interval(45, 70, closed=\"right\"),\n", " \"70+\": pd.Interval(70, 10000, closed=\"right\"),\n", "}\n", "MODEL_TYPE = [str(s) for s in df[\"T\"].unique()]\n", "Precision = [str(s) for s in df[\"Precision\"].unique()]\n", "\n", "# Searching and filtering\n", "def update_table(\n", " hidden_df: pd.DataFrame,\n", " columns: list,\n", " type_query: list,\n", " precision_query: str,\n", " size_query: list,\n", " query: str,\n", "):\n", " filtered_df = filter_models(hidden_df, type_query, size_query, precision_query) # type: ignore\n", " filtered_df = filter_queries(query, filtered_df)\n", " df = select_columns(filtered_df, columns)\n", " return df\n", "\n", "def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:\n", " return df[(df[\"model_name_for_query\"].str.contains(query, case=False))] # type: ignore\n", "\n", "def select_columns(df: pd.DataFrame, columns: list) -> pd.DataFrame:\n", " # We use COLS to maintain sorting\n", " filtered_df = df[[c for c in COLS if c in df.columns and c in columns]]\n", " return filtered_df # type: ignore\n", "\n", "def filter_queries(query: str, filtered_df: pd.DataFrame) -> pd.DataFrame:\n", " final_df = []\n", " if query != \"\":\n", " queries = [q.strip() for q in query.split(\";\")]\n", " for _q in queries:\n", " _q = _q.strip()\n", " if _q != \"\":\n", " temp_filtered_df = search_table(filtered_df, _q)\n", " if len(temp_filtered_df) > 0:\n", " final_df.append(temp_filtered_df)\n", " if len(final_df) > 0:\n", " filtered_df = pd.concat(final_df)\n", " filtered_df = filtered_df.drop_duplicates( # type: ignore\n", " subset=[\"Model\", \"Precision\", \"Model sha\"]\n", " )\n", "\n", " return filtered_df\n", "\n", "def filter_models(\n", " df: pd.DataFrame,\n", " type_query: list,\n", " size_query: list,\n", " precision_query: list,\n", ") -> pd.DataFrame:\n", " # Show all models\n", " filtered_df = df\n", "\n", " type_emoji = [t[0] for t in type_query]\n", " filtered_df = filtered_df.loc[df[\"T\"].isin(type_emoji)]\n", " filtered_df = filtered_df.loc[df[\"Precision\"].isin(precision_query + [\"None\"])]\n", "\n", " numeric_interval = pd.IntervalIndex(\n", " sorted([NUMERIC_INTERVALS[s] for s in size_query]) # type: ignore\n", " )\n", " params_column = pd.to_numeric(df[\"#Params (B)\"], errors=\"coerce\")\n", " mask = params_column.apply(lambda x: any(numeric_interval.contains(x))) # type: ignore\n", " filtered_df = filtered_df.loc[mask]\n", "\n", " return filtered_df\n", "\n", "demo = gr.Blocks(css=str(abs_path / \"assets/leaderboard_data.json\"))\n", "with demo:\n", " gr.Markdown(\"\"\"Test Space of the LLM Leaderboard\"\"\", elem_classes=\"markdown-text\")\n", "\n", " with gr.Tabs(elem_classes=\"tab-buttons\") as tabs:\n", " with gr.TabItem(\"\ud83c\udfc5 LLM Benchmark\", elem_id=\"llm-benchmark-tab-table\", id=0):\n", " with gr.Row():\n", " with gr.Column():\n", " with gr.Row():\n", " search_bar = gr.Textbox(\n", " placeholder=\" \ud83d\udd0d Search for your model (separate multiple queries with `;`) and press ENTER...\",\n", " show_label=False,\n", " elem_id=\"search-bar\",\n", " )\n", " with gr.Row():\n", " shown_columns = gr.CheckboxGroup(\n", " choices=COLS,\n", " value=ON_LOAD_COLS,\n", " label=\"Select columns to show\",\n", " elem_id=\"column-select\",\n", " interactive=True,\n", " )\n", " with gr.Column(min_width=320):\n", " filter_columns_type = gr.CheckboxGroup(\n", " label=\"Model types\",\n", " choices=MODEL_TYPE,\n", " value=MODEL_TYPE,\n", " interactive=True,\n", " elem_id=\"filter-columns-type\",\n", " )\n", " filter_columns_precision = gr.CheckboxGroup(\n", " label=\"Precision\",\n", " choices=Precision,\n", " value=Precision,\n", " interactive=True,\n", " elem_id=\"filter-columns-precision\",\n", " )\n", " filter_columns_size = gr.CheckboxGroup(\n", " label=\"Model sizes (in billions of parameters)\",\n", " choices=list(NUMERIC_INTERVALS.keys()),\n", " value=list(NUMERIC_INTERVALS.keys()),\n", " interactive=True,\n", " elem_id=\"filter-columns-size\",\n", " )\n", "\n", " leaderboard_table = gr.components.Dataframe(\n", " value=df[ON_LOAD_COLS], # type: ignore\n", " headers=ON_LOAD_COLS,\n", " datatype=TYPES,\n", " elem_id=\"leaderboard-table\",\n", " interactive=False,\n", " visible=True,\n", " column_widths=[\"2%\", \"33%\"],\n", " )\n", "\n", " # Dummy leaderboard for handling the case when the user uses backspace key\n", " hidden_leaderboard_table_for_search = gr.components.Dataframe(\n", " value=invisible_df[COLS], # type: ignore\n", " headers=COLS,\n", " datatype=TYPES,\n", " visible=False,\n", " )\n", " search_bar.submit(\n", " update_table,\n", " [\n", " hidden_leaderboard_table_for_search,\n", " shown_columns,\n", " filter_columns_type,\n", " filter_columns_precision,\n", " filter_columns_size,\n", " search_bar,\n", " ],\n", " leaderboard_table,\n", " )\n", " for selector in [\n", " shown_columns,\n", " filter_columns_type,\n", " filter_columns_precision,\n", " filter_columns_size,\n", " ]:\n", " selector.change(\n", " update_table,\n", " [\n", " hidden_leaderboard_table_for_search,\n", " shown_columns,\n", " filter_columns_type,\n", " filter_columns_precision,\n", " filter_columns_size,\n", " search_bar,\n", " ],\n", " leaderboard_table,\n", " queue=True,\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue(default_concurrency_limit=40).launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: mini_leaderboard"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio pandas "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('assets')\n", "!wget -q -O assets/__init__.py https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/__init__.py\n", "!wget -q -O assets/custom_css.css https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/custom_css.css\n", "!wget -q -O assets/leaderboard_data.json https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/leaderboard_data.json"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["# type: ignore\n", "import gradio as gr\n", "import pandas as pd\n", "from pathlib import Path\n", "\n", "abs_path = Path(__file__).parent.absolute()\n", "\n", "df = pd.read_json(str(abs_path / \"assets/leaderboard_data.json\"))\n", "invisible_df = df.copy()\n", "\n", "COLS = [\n", " \"T\",\n", " \"Model\",\n", " \"Average \u2b06\ufe0f\",\n", " \"ARC\",\n", " \"HellaSwag\",\n", " \"MMLU\",\n", " \"TruthfulQA\",\n", " \"Winogrande\",\n", " \"GSM8K\",\n", " \"Type\",\n", " \"Architecture\",\n", " \"Precision\",\n", " \"Merged\",\n", " \"Hub License\",\n", " \"#Params (B)\",\n", " \"Hub \u2764\ufe0f\",\n", " \"Model sha\",\n", " \"model_name_for_query\",\n", "]\n", "ON_LOAD_COLS = [\n", " \"T\",\n", " \"Model\",\n", " \"Average \u2b06\ufe0f\",\n", " \"ARC\",\n", " \"HellaSwag\",\n", " \"MMLU\",\n", " \"TruthfulQA\",\n", " \"Winogrande\",\n", " \"GSM8K\",\n", " \"model_name_for_query\",\n", "]\n", "TYPES = [\n", " \"str\",\n", " \"markdown\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"number\",\n", " \"str\",\n", " \"str\",\n", " \"str\",\n", " \"str\",\n", " \"bool\",\n", " \"str\",\n", " \"number\",\n", " \"number\",\n", " \"bool\",\n", " \"str\",\n", " \"bool\",\n", " \"bool\",\n", " \"str\",\n", "]\n", "NUMERIC_INTERVALS = {\n", " \"?\": pd.Interval(-1, 0, closed=\"right\"),\n", " \"~1.5\": pd.Interval(0, 2, closed=\"right\"),\n", " \"~3\": pd.Interval(2, 4, closed=\"right\"),\n", " \"~7\": pd.Interval(4, 9, closed=\"right\"),\n", " \"~13\": pd.Interval(9, 20, closed=\"right\"),\n", " \"~35\": pd.Interval(20, 45, closed=\"right\"),\n", " \"~60\": pd.Interval(45, 70, closed=\"right\"),\n", " \"70+\": pd.Interval(70, 10000, closed=\"right\"),\n", "}\n", "MODEL_TYPE = [str(s) for s in df[\"T\"].unique()]\n", "Precision = [str(s) for s in df[\"Precision\"].unique()]\n", "\n", "# Searching and filtering\n", "def update_table(\n", " hidden_df: pd.DataFrame,\n", " columns: list,\n", " type_query: list,\n", " precision_query: str,\n", " size_query: list,\n", " query: str,\n", "):\n", " filtered_df = filter_models(hidden_df, type_query, size_query, precision_query) # type: ignore\n", " filtered_df = filter_queries(query, filtered_df)\n", " df = select_columns(filtered_df, columns)\n", " return df\n", "\n", "def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:\n", " return df[(df[\"model_name_for_query\"].str.contains(query, case=False))] # type: ignore\n", "\n", "def select_columns(df: pd.DataFrame, columns: list) -> pd.DataFrame:\n", " # We use COLS to maintain sorting\n", " filtered_df = df[[c for c in COLS if c in df.columns and c in columns]]\n", " return filtered_df # type: ignore\n", "\n", "def filter_queries(query: str, filtered_df: pd.DataFrame) -> pd.DataFrame:\n", " final_df = []\n", " if query != \"\":\n", " queries = [q.strip() for q in query.split(\";\")]\n", " for _q in queries:\n", " _q = _q.strip()\n", " if _q != \"\":\n", " temp_filtered_df = search_table(filtered_df, _q)\n", " if len(temp_filtered_df) > 0:\n", " final_df.append(temp_filtered_df)\n", " if len(final_df) > 0:\n", " filtered_df = pd.concat(final_df)\n", " filtered_df = filtered_df.drop_duplicates( # type: ignore\n", " subset=[\"Model\", \"Precision\", \"Model sha\"]\n", " )\n", "\n", " return filtered_df\n", "\n", "def filter_models(\n", " df: pd.DataFrame,\n", " type_query: list,\n", " size_query: list,\n", " precision_query: list,\n", ") -> pd.DataFrame:\n", " # Show all models\n", " filtered_df = df\n", "\n", " type_emoji = [t[0] for t in type_query]\n", " filtered_df = filtered_df.loc[df[\"T\"].isin(type_emoji)]\n", " filtered_df = filtered_df.loc[df[\"Precision\"].isin(precision_query + [\"None\"])]\n", "\n", " numeric_interval = pd.IntervalIndex(\n", " sorted([NUMERIC_INTERVALS[s] for s in size_query]) # type: ignore\n", " )\n", " params_column = pd.to_numeric(df[\"#Params (B)\"], errors=\"coerce\")\n", " mask = params_column.apply(lambda x: any(numeric_interval.contains(x))) # type: ignore\n", " filtered_df = filtered_df.loc[mask]\n", "\n", " return filtered_df\n", "\n", "demo = gr.Blocks(css=str(abs_path / \"assets/leaderboard_data.json\"))\n", "with demo:\n", " gr.Markdown(\"\"\"Test Space of the LLM Leaderboard\"\"\", elem_classes=\"markdown-text\")\n", "\n", " with gr.Tabs(elem_classes=\"tab-buttons\") as tabs:\n", " with gr.TabItem(\"\ud83c\udfc5 LLM Benchmark\", elem_id=\"llm-benchmark-tab-table\", id=0):\n", " with gr.Row():\n", " with gr.Column():\n", " with gr.Row():\n", " search_bar = gr.Textbox(\n", " placeholder=\" \ud83d\udd0d Search for your model (separate multiple queries with `;`) and press ENTER...\",\n", " show_label=False,\n", " elem_id=\"search-bar\",\n", " )\n", " with gr.Row():\n", " shown_columns = gr.CheckboxGroup(\n", " choices=COLS,\n", " value=ON_LOAD_COLS,\n", " label=\"Select columns to show\",\n", " elem_id=\"column-select\",\n", " interactive=True,\n", " )\n", " with gr.Column(min_width=320):\n", " filter_columns_type = gr.CheckboxGroup(\n", " label=\"Model types\",\n", " choices=MODEL_TYPE,\n", " value=MODEL_TYPE,\n", " interactive=True,\n", " elem_id=\"filter-columns-type\",\n", " )\n", " filter_columns_precision = gr.CheckboxGroup(\n", " label=\"Precision\",\n", " choices=Precision,\n", " value=Precision,\n", " interactive=True,\n", " elem_id=\"filter-columns-precision\",\n", " )\n", " filter_columns_size = gr.CheckboxGroup(\n", " label=\"Model sizes (in billions of parameters)\",\n", " choices=list(NUMERIC_INTERVALS.keys()),\n", " value=list(NUMERIC_INTERVALS.keys()),\n", " interactive=True,\n", " elem_id=\"filter-columns-size\",\n", " )\n", "\n", " leaderboard_table = gr.components.Dataframe(\n", " value=df[ON_LOAD_COLS], # type: ignore\n", " headers=ON_LOAD_COLS,\n", " datatype=TYPES,\n", " elem_id=\"leaderboard-table\",\n", " interactive=False,\n", " visible=True,\n", " column_widths=[\"2%\", \"33%\"],\n", " )\n", "\n", " # Dummy leaderboard for handling the case when the user uses backspace key\n", " hidden_leaderboard_table_for_search = gr.components.Dataframe(\n", " value=invisible_df[COLS], # type: ignore\n", " headers=COLS,\n", " datatype=TYPES,\n", " visible=False,\n", " )\n", " search_bar.submit(\n", " update_table,\n", " [\n", " hidden_leaderboard_table_for_search,\n", " shown_columns,\n", " filter_columns_type,\n", " filter_columns_precision,\n", " filter_columns_size,\n", " search_bar,\n", " ],\n", " leaderboard_table,\n", " )\n", " for selector in [\n", " shown_columns,\n", " filter_columns_type,\n", " filter_columns_precision,\n", " filter_columns_size,\n", " ]:\n", " selector.change(\n", " update_table,\n", " [\n", " hidden_leaderboard_table_for_search,\n", " shown_columns,\n", " filter_columns_type,\n", " filter_columns_precision,\n", " filter_columns_size,\n", " search_bar,\n", " ],\n", " leaderboard_table,\n", " queue=True,\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue(default_concurrency_limit=40).launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/mini_leaderboard/run.py b/demo/mini_leaderboard/run.py index 6eef4c3bc6a91..40711d02e76ad 100644 --- a/demo/mini_leaderboard/run.py +++ b/demo/mini_leaderboard/run.py @@ -1,3 +1,4 @@ +# type: ignore import gradio as gr import pandas as pd from pathlib import Path diff --git a/gradio/components/dataframe.py b/gradio/components/dataframe.py index a41c73aeb30f5..4440beb1acf46 100644 --- a/gradio/components/dataframe.py +++ b/gradio/components/dataframe.py @@ -74,7 +74,10 @@ def __init__( headers: list[str] | None = None, row_count: int | tuple[int, str] = (1, "dynamic"), col_count: int | tuple[int, str] | None = None, - datatype: str | list[str] = "str", + datatype: Literal["str", "number", "bool", "date", "markdown", "html"] + | Sequence[ + Literal["str", "number", "bool", "date", "markdown", "html"] + ] = "str", type: Literal["pandas", "numpy", "array", "polars"] = "pandas", latex_delimiters: list[dict[str, str | bool]] | None = None, label: str | None = None, @@ -99,8 +102,8 @@ def __init__( ): """ Parameters: - value: Default value to display in the DataFrame. If a Styler is provided, it will be used to set the displayed value in the DataFrame (e.g. to set precision of numbers) if the `interactive` is False. If a Callable function is provided, the function will be called whenever the app loads to set the initial value of the component. - headers: List of str header names. If None, no headers are shown. + value: Default value to display in the DataFrame. Supports pandas, numpy, polars, and list of lists. If a Styler is provided, it will be used to set the displayed value in the DataFrame (e.g. to set precision of numbers) if the `interactive` is False. If a Callable function is provided, the function will be called whenever the app loads to set the initial value of the component. + headers: List of str header names. These are used to set the column headers of the dataframe if the value does not have headers. If None, no headers are shown. row_count: Limit number of rows for input and decide whether user can create new rows or delete existing rows. The first element of the tuple is an `int`, the row count; the second should be 'fixed' or 'dynamic', the new row behaviour. If an `int` is passed the rows default to 'dynamic' col_count: Limit number of columns for input and decide whether user can create new columns or delete existing columns. The first element of the tuple is an `int`, the number of columns; the second should be 'fixed' or 'dynamic', the new column behaviour. If an `int` is passed the columns default to 'dynamic' datatype: Datatype of values in sheet. Can be provided per column as a list of strings, or for the entire sheet as a single string. Valid datatypes are "str", "number", "bool", "date", and "markdown". @@ -150,24 +153,6 @@ def __init__( "Polars is not installed. Please install using `pip install polars`." ) self.type = type - values = { - "str": "", - "number": 0, - "bool": False, - "date": "01/01/1970", - "markdown": "", - "html": "", - } - column_dtypes = ( - [datatype] * self.col_count[0] if isinstance(datatype, str) else datatype - ) - self.empty_input = { - "headers": self.headers, - "data": [ - [values[c] for c in column_dtypes] for _ in range(self.row_count[0]) - ], - "metadata": None, - } if latex_delimiters is None: latex_delimiters = [{"left": "$$", "right": "$$", "display": True}] @@ -235,7 +220,7 @@ def preprocess( ) @staticmethod - def _is_empty( + def is_empty( value: pd.DataFrame | Styler | np.ndarray @@ -246,9 +231,14 @@ def _is_empty( | str | None, ) -> bool: + """ + Checks if the value of the dataframe provided is empty. + """ import pandas as pd from pandas.io.formats.style import Styler + if value is None: + return True if isinstance(value, pd.DataFrame): return value.empty elif isinstance(value, Styler): @@ -257,13 +247,17 @@ def _is_empty( return value.size == 0 elif _is_polars_available() and isinstance(value, _import_polars().DataFrame): return value.is_empty() - elif isinstance(value, list) and len(value) and isinstance(value[0], list): - return len(value[0]) == 0 - elif isinstance(value, (list, dict)): + elif isinstance(value, list): + if len(value) > 0 and isinstance(value[0], list): + return len(value[0]) == 0 + return len(value) == 0 + elif isinstance(value, dict): + if "data" in value: + return len(value["data"]) == 0 return len(value) == 0 return False - def postprocess( + def get_headers( self, value: pd.DataFrame | Styler @@ -274,102 +268,153 @@ def postprocess( | dict | str | None, - ) -> DataframeData: + ) -> list[str]: """ - Parameters: - value: Expects data in any of these formats: `pandas.DataFrame`, `pandas.Styler`, `numpy.array`, `polars.DataFrame`, `list[list]`, `list`, or a `dict` with keys 'data' (and optionally 'headers'), or `str` path to a csv, which is rendered as the spreadsheet. - Returns: - the uploaded spreadsheet data as an object with `headers` and `data` keys and optional `metadata` key + Returns the headers of the dataframes based on the value provided. For values + that do not have headers, an empty list is returned. """ import pandas as pd from pandas.io.formats.style import Styler - if isinstance(value, Styler) and semantic_version.Version( - pd.__version__ - ) < semantic_version.Version("1.5.0"): - raise ValueError( - "Styler objects are only supported in pandas version 1.5.0 or higher. Please try: `pip install --upgrade pandas` to use this feature." - ) + if value is None: + return [] + if isinstance(value, pd.DataFrame): + return list(value.columns) + elif isinstance(value, Styler): + return list(value.data.columns) # type: ignore + elif isinstance(value, str): + return list(pd.read_csv(value).columns) + elif _is_polars_available() and isinstance(value, _import_polars().DataFrame): + return list(value.columns) + elif isinstance(value, dict): + return value.get("headers", []) + elif isinstance(value, (list, np.ndarray)): + return [] + return [] + + @staticmethod + def get_cell_data( + value: pd.DataFrame + | Styler + | np.ndarray + | pl.DataFrame + | list + | list[list] + | dict + | str + | None, + ) -> list[list[Any]]: + """ + Gets the cell data (as a list of lists) from the value provided. + """ + import pandas as pd + from pandas.io.formats.style import Styler - if value is None or self._is_empty(value): - return DataframeData( - headers=self.headers, data=[["" for _ in range(len(self.headers))]] - ) if isinstance(value, dict): - if len(value) == 0: - return DataframeData( - headers=self.headers, data=[["" for _ in range(len(self.headers))]] - ) - return DataframeData( - headers=value.get("headers", []), data=value.get("data", [[]]) - ) + return value.get("data", [[]]) if isinstance(value, (str, pd.DataFrame)): if isinstance(value, str): value = pd.read_csv(value) # type: ignore - if len(value) == 0: - return DataframeData( - headers=[str(col) for col in value.columns], # Convert to strings - data=[["" for _ in range(len(value.columns))]], - ) - return DataframeData( - headers=[str(col) for col in value.columns], - data=value.to_dict(orient="split")["data"], - ) + return value.to_dict(orient="split")["data"] elif isinstance(value, Styler): - if self.interactive: - warnings.warn( - "Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead." - ) df: pd.DataFrame = value.data # type: ignore + hidden_columns = getattr(value, "hidden_columns", []) visible_cols = [ - i - for i, col in enumerate(df.columns) - if i not in getattr(value, "hidden_columns", []) + i for i, _ in enumerate(df.columns) if i not in hidden_columns ] df = df.iloc[:, visible_cols] - - if len(df) == 0: - return DataframeData( - headers=list(df.columns), - data=[["" for _ in range(len(df.columns))]], - metadata=self.__extract_metadata( - value, getattr(value, "hidden_columns", []) - ), # type: ignore - ) - return DataframeData( - headers=list(df.columns), - data=df.to_dict(orient="split")["data"], # type: ignore - metadata=self.__extract_metadata( - value, getattr(value, "hidden_columns", []) - ), # type: ignore - ) + return df.to_dict(orient="split")["data"] elif _is_polars_available() and isinstance(value, _import_polars().DataFrame): - if len(value) == 0: - return DataframeData(headers=list(value.to_dict().keys()), data=[[]]) # type: ignore df_dict = value.to_dict() # type: ignore - headers = list(df_dict.keys()) data = list(zip(*df_dict.values())) - return DataframeData(headers=headers, data=data) + return data elif isinstance(value, (np.ndarray, list)): - if len(value) == 0: - return DataframeData(headers=self.headers, data=[[]]) if isinstance(value, np.ndarray): value = value.tolist() if not isinstance(value, list): raise ValueError("output cannot be converted to list") + if not isinstance(value[0], list): + return [[v] for v in value] + return value + else: + raise ValueError( + f"Cannot process value of type {type(value)} in gr.Dataframe" + ) - _headers = self.headers - if len(self.headers) < len(value[0]): - _headers: list[str] = [ - *self.headers, - *[str(i) for i in range(len(self.headers) + 1, len(value[0]) + 1)], - ] - elif len(self.headers) > len(value[0]): - _headers = self.headers[: len(value[0])] + @staticmethod + def get_metadata( + value: pd.DataFrame + | Styler + | np.ndarray + | pl.DataFrame + | list + | list[list] + | dict + | str + | None, + ) -> dict[str, list[list]] | None: + """ + Gets the metadata from the value provided. + """ + from pandas.io.formats.style import Styler - return DataframeData(headers=_headers, data=value) - else: - raise ValueError("Cannot process value as a Dataframe") + if isinstance(value, Styler): + return Dataframe.__extract_metadata( + value, getattr(value, "hidden_columns", []) + ) + return None + + def postprocess( + self, + value: pd.DataFrame + | Styler + | np.ndarray + | pl.DataFrame + | list + | list[list] + | dict + | str + | None, + ) -> DataframeData: + """ + Parameters: + value: Expects data in any of these formats: `pandas.DataFrame`, `pandas.Styler`, `numpy.array`, `polars.DataFrame`, `list[list]`, `list`, or a `dict` with keys 'data' (and optionally 'headers'), or `str` path to a csv, which is rendered as the spreadsheet. + Returns: + the uploaded spreadsheet data as an object with `headers` and `data` keys and optional `metadata` key + """ + import pandas as pd + from pandas.io.formats.style import Styler + + if isinstance(value, Styler) and semantic_version.Version( + pd.__version__ + ) < semantic_version.Version("1.5.0"): + raise ValueError( + "Styler objects are only supported in pandas version 1.5.0 or higher. Please try: `pip install --upgrade pandas` to use this feature." + ) + if isinstance(value, Styler) and self.interactive: + warnings.warn( + "Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead." + ) + + headers = self.get_headers(value) or self.headers + data = ( + [["" for _ in range(len(headers))]] + if self.is_empty(value) + else self.get_cell_data(value) + ) + if len(headers) > len(data[0]): + headers = headers[: len(data[0])] + elif len(headers) < len(data[0]): + headers = [ + *headers, + *[str(i) for i in range(len(headers) + 1, len(data[0]) + 1)], + ] + metadata = self.get_metadata(value) + return DataframeData( + headers=headers, + data=data, + metadata=metadata, # type: ignore + ) @staticmethod def __get_cell_style(cell_id: str, cell_styles: list[dict]) -> str: diff --git a/gradio/templates.py b/gradio/templates.py index 71f05b2277323..1cbabe30fd29e 100644 --- a/gradio/templates.py +++ b/gradio/templates.py @@ -579,7 +579,10 @@ def __init__( headers: list[str] | None = None, row_count: int | tuple[int, str] = (1, "dynamic"), col_count: int | tuple[int, str] | None = None, - datatype: str | list[str] = "str", + datatype: Literal["str", "number", "bool", "date", "markdown", "html"] + | Sequence[ + Literal["str", "number", "bool", "date", "markdown", "html"] + ] = "str", type: Literal["numpy"] = "numpy", latex_delimiters: list[dict[str, str | bool]] | None = None, label: str | None = None, @@ -649,7 +652,10 @@ def __init__( headers: list[str] | None = None, row_count: int | tuple[int, str] = (1, "dynamic"), col_count: int | tuple[int, str] | None = None, - datatype: str | list[str] = "str", + datatype: Literal["str", "number", "bool", "date", "markdown", "html"] + | Sequence[ + Literal["str", "number", "bool", "date", "markdown", "html"] + ] = "str", type: Literal["array"] = "array", latex_delimiters: list[dict[str, str | bool]] | None = None, label: str | None = None, @@ -719,7 +725,10 @@ def __init__( headers: list[str] | None = None, row_count: int | tuple[int, str] = (1, "dynamic"), col_count: Literal[1] = 1, - datatype: str | list[str] = "str", + datatype: Literal["str", "number", "bool", "date", "markdown", "html"] + | Sequence[ + Literal["str", "number", "bool", "date", "markdown", "html"] + ] = "str", type: Literal["array"] = "array", latex_delimiters: list[dict[str, str | bool]] | None = None, label: str | None = None, diff --git a/test/components/test_dataframe.py b/test/components/test_dataframe.py index f50c29d8871e7..285b3d9378db3 100644 --- a/test/components/test_dataframe.py +++ b/test/components/test_dataframe.py @@ -355,3 +355,42 @@ def test_dataframe_hidden_columns(self): ], }, } + + def test_is_empty(self): + """Test is_empty method with various data types""" + df = gr.Dataframe() + assert df.is_empty([]) + assert df.is_empty([[]]) + assert df.is_empty(np.array([])) + assert df.is_empty(np.zeros((2, 0))) + assert df.is_empty(None) + assert df.is_empty({}) + assert df.is_empty({"data": [], "headers": ["a", "b"]}) + assert not df.is_empty({"data": [1, 2]}) + assert not df.is_empty([[1, 2], [3, 4]]) + assert not df.is_empty(pd.DataFrame({"a": [1, 2]})) + assert not df.is_empty(pd.DataFrame({"a": [1, 2]}).style) + + def test_get_headers(self): + """Test get_headers method with various data types""" + df = gr.Dataframe() + test_df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + assert df.get_headers(test_df) == ["col1", "col2"] + assert df.get_headers(test_df.style) == ["col1", "col2"] + assert df.get_headers({"headers": ["a", "b"]}) == ["a", "b"] + assert df.get_headers(np.array([[1, 2], [3, 4]])) == [] + assert df.get_headers(None) == [] + + def test_get_cell_data(self): + """Test get_cell_data method with various data types""" + df = gr.Dataframe() + test_data = [[1, 2], [3, 4]] + test_df = pd.DataFrame({"col1": [1, 3], "col2": [2, 4]}) + assert df.get_cell_data(test_data) == [[1, 2], [3, 4]] + assert df.get_cell_data(test_df) == [[1, 2], [3, 4]] + assert df.get_cell_data({"data": test_data}) == [[1, 2], [3, 4]] + assert df.get_cell_data(np.array([1, 2, 3])) == [[1], [2], [3]] + + styled_df = test_df.style + styled_df.hide(axis=1, subset=["col2"]) + assert df.get_cell_data(styled_df) == [[1], [3]]