diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..1d36346c0d --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max-line-length = 88 \ No newline at end of file diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 937ef9b5da..24135fa293 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -322,7 +322,6 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): and args["y"] and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 ): - # sorting is bad but trace_specs with "trendline" have no other attrs sorted_trace_data = trace_data.sort_values(by=args["x"]) y = sorted_trace_data[args["y"]].values @@ -563,7 +562,6 @@ def set_cartesian_axis_opts(args, axis, letter, orders): def configure_cartesian_marginal_axes(args, fig, orders): - if "histogram" in [args["marginal_x"], args["marginal_y"]]: fig.layout["barmode"] = "overlay" @@ -1065,14 +1063,14 @@ def _escape_col_name(columns, col_name, extra): return col_name -def to_unindexed_series(x): +def to_unindexed_series(x, name=None): """ assuming x is list-like or even an existing pd.Series, return a new pd.Series with no index, without extracting the data from an existing Series via numpy, which seems to mangle datetime columns. 
Stripping the index from existing pd.Series is required to get things to match up right in the new DataFrame we're building """ - return pd.Series(x).reset_index(drop=True) + return pd.Series(x, name=name).reset_index(drop=True) def process_args_into_dataframe(args, wide_mode, var_name, value_name): @@ -1087,9 +1085,12 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): df_input = args["data_frame"] df_provided = df_input is not None - df_output = pd.DataFrame() - constants = dict() - ranges = list() + # we use a dict instead of a dataframe directly so that it doesn't trigger + # a pandas PerformanceWarning by repeatedly setting the columns. + # a dict is used instead of a list as the columns need to be overwritten. + df_output = {} + constants = {} + ranges = [] wide_id_vars = set() reserved_names = _get_reserved_col_names(args) if df_provided else set() @@ -1100,7 +1101,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
) else: - df_output[df_input.columns] = df_input[df_input.columns] + df_output = {col: series for col, series in df_input.items()} # hover_data is a dict hover_data_is_dict = ( @@ -1141,7 +1142,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): # argument_list and field_list ready, iterate over them # Core of the loop starts here for i, (argument, field) in enumerate(zip(argument_list, field_list)): - length = len(df_output) + length = len(df_output[next(iter(df_output))]) if len(df_output) else 0 if argument is None: continue col_name = None @@ -1182,11 +1183,11 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): % ( argument, len(real_argument), - str(list(df_output.columns)), + str(list(df_output.keys())), length, ) ) - df_output[col_name] = to_unindexed_series(real_argument) + df_output[col_name] = to_unindexed_series(real_argument, col_name) elif not df_provided: raise ValueError( "String or int arguments are only possible when a " @@ -1215,13 +1216,15 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): % ( field, len(df_input[argument]), - str(list(df_output.columns)), + str(list(df_output.keys())), length, ) ) else: col_name = str(argument) - df_output[col_name] = to_unindexed_series(df_input[argument]) + df_output[col_name] = to_unindexed_series( + df_input[argument], col_name + ) # ----------------- argument is likely a column / array / list.... ------- else: if df_provided and hasattr(argument, "name"): @@ -1248,9 +1251,9 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): "All arguments should have the same length. 
" "The length of argument `%s` is %d, whereas the " "length of previously-processed arguments %s is %d" - % (field, len(argument), str(list(df_output.columns)), length) + % (field, len(argument), str(list(df_output.keys())), length) ) - df_output[str(col_name)] = to_unindexed_series(argument) + df_output[str(col_name)] = to_unindexed_series(argument, str(col_name)) # Finally, update argument with column name now that column exists assert col_name is not None, ( @@ -1268,12 +1271,19 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): if field_name != "wide_variable": wide_id_vars.add(str(col_name)) - for col_name in ranges: - df_output[col_name] = range(len(df_output)) - - for col_name in constants: - df_output[col_name] = constants[col_name] + length = len(df_output[next(iter(df_output))]) if len(df_output) else 0 + df_output.update( + {col_name: to_unindexed_series(range(length), col_name) for col_name in ranges} + ) + df_output.update( + { + # constant is single value. 
repeat by len to avoid creating NaN when concatenating + col_name: to_unindexed_series([constants[col_name]] * length, col_name) + for col_name in constants + } + ) + df_output = pd.DataFrame(df_output) return df_output, wide_id_vars diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py index 9aef665760..1aac7b70ea 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py @@ -1,9 +1,11 @@ import plotly.express as px import plotly.graph_objects as go import pandas as pd +import numpy as np from plotly.express._core import build_dataframe, _is_col_list from pandas.testing import assert_frame_equal import pytest +import warnings def test_is_col_list(): @@ -847,3 +849,29 @@ def test_line_group(): assert len(fig.data) == 4 fig = px.scatter(df, x="x", y=["miss", "score"], color="who") assert len(fig.data) == 2 + + +def test_no_pd_perf_warning(): + n_cols = 1000 + n_rows = 1000 + + columns = list(f"col_{c}" for c in range(n_cols)) + index = list(f"i_{r}" for r in range(n_rows)) + + df = pd.DataFrame( + np.random.uniform(size=(n_rows, n_cols)), index=index, columns=columns + ) + + with warnings.catch_warnings(record=True) as warn_list: + _ = px.bar( + df, + x=df.index, + y=df.columns[:-2], + labels=df.columns[:-2], + ) + performance_warnings = [ + warn + for warn in warn_list + if issubclass(warn.category, pd.errors.PerformanceWarning) + ] + assert len(performance_warnings) == 0, "PerformanceWarning(s) raised!"