From 2dacf6b5beb846bab1c46b6939388bedf1ccd6d1 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Thu, 10 Feb 2022 15:12:44 +0800 Subject: [PATCH 1/5] wip --- superset/utils/pandas_postprocessing.py | 1002 ----------------- .../utils/pandas_postprocessing/__init__.py | 17 + .../utils/pandas_postprocessing/aggregate.py | 47 + .../utils/pandas_postprocessing/boxplot.py | 127 +++ .../utils/pandas_postprocessing/compare.py | 81 ++ .../pandas_postprocessing/contribution.py | 77 ++ superset/utils/pandas_postprocessing/cum.py | 68 ++ superset/utils/pandas_postprocessing/diff.py | 41 + .../utils/pandas_postprocessing/geography.py | 104 ++ superset/utils/pandas_postprocessing/pivot.py | 112 ++ .../utils/pandas_postprocessing/prophet.py | 147 +++ .../utils/pandas_postprocessing/resample.py | 55 + .../utils/pandas_postprocessing/rolling.py | 102 ++ .../utils/pandas_postprocessing/select.py | 48 + superset/utils/pandas_postprocessing/sort.py | 30 + superset/utils/pandas_postprocessing/utils.py | 215 ++++ 16 files changed, 1271 insertions(+), 1002 deletions(-) delete mode 100644 superset/utils/pandas_postprocessing.py create mode 100644 superset/utils/pandas_postprocessing/__init__.py create mode 100644 superset/utils/pandas_postprocessing/aggregate.py create mode 100644 superset/utils/pandas_postprocessing/boxplot.py create mode 100644 superset/utils/pandas_postprocessing/compare.py create mode 100644 superset/utils/pandas_postprocessing/contribution.py create mode 100644 superset/utils/pandas_postprocessing/cum.py create mode 100644 superset/utils/pandas_postprocessing/diff.py create mode 100644 superset/utils/pandas_postprocessing/geography.py create mode 100644 superset/utils/pandas_postprocessing/pivot.py create mode 100644 superset/utils/pandas_postprocessing/prophet.py create mode 100644 superset/utils/pandas_postprocessing/resample.py create mode 100644 superset/utils/pandas_postprocessing/rolling.py create mode 100644 superset/utils/pandas_postprocessing/select.py create mode 100644 superset/utils/pandas_postprocessing/sort.py create mode 100644 superset/utils/pandas_postprocessing/utils.py diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py deleted file mode 100644 index beef42a36ff26..0000000000000 --- a/superset/utils/pandas_postprocessing.py +++ /dev/null @@ -1,1002 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=too-many-lines -import logging -from decimal import Decimal -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union - -import geohash as geohash_lib -import numpy as np -import pandas as pd -from flask_babel import gettext as _ -from geopy.point import Point -from pandas import DataFrame, NamedAgg, Series, Timestamp - -from superset.constants import NULL_STRING, PandasAxis, PandasPostprocessingCompare -from superset.exceptions import QueryObjectValidationError -from superset.utils.core import ( - DTTM_ALIAS, - PostProcessingBoxplotWhiskerType, - PostProcessingContributionOrientation, - TIME_COMPARISION, -) - -NUMPY_FUNCTIONS = { - "average": np.average, - "argmin": np.argmin, - "argmax": np.argmax, - "count": np.ma.count, - "count_nonzero": np.count_nonzero, - "cumsum": np.cumsum, - "cumprod": np.cumprod, - "max": np.max, - "mean": np.mean, - "median": np.median, - "nansum": np.nansum, - "nanmin": np.nanmin, - "nanmax": np.nanmax, - "nanmean": np.nanmean, - "nanmedian": np.nanmedian, - "nanpercentile": np.nanpercentile, - "min": np.min, - "percentile": np.percentile, - "prod": np.prod, - "product": np.product, - "std": np.std, - "sum": np.sum, - "var": np.var, -} - -DENYLIST_ROLLING_FUNCTIONS = ( - "count", - "corr", - "cov", - "kurt", - "max", - "mean", - "median", - "min", - "std", - "skew", - "sum", - "var", - "quantile", -) - -ALLOWLIST_CUMULATIVE_FUNCTIONS = ( - "cummax", - "cummin", - "cumprod", - "cumsum", -) - -PROPHET_TIME_GRAIN_MAP = { - "PT1S": "S", - "PT1M": "min", - "PT5M": "5min", - "PT10M": "10min", - "PT15M": "15min", - "PT30M": "30min", - "PT1H": "H", - "P1D": "D", - "P1W": "W", - "P1M": "M", - "P3M": "Q", - "P1Y": "A", - "1969-12-28T00:00:00Z/P1W": "W", - "1969-12-29T00:00:00Z/P1W": "W", - "P1W/1970-01-03T00:00:00Z": "W", - "P1W/1970-01-04T00:00:00Z": "W", -} - - -def _flatten_column_after_pivot( - column: Union[float, Timestamp, str, Tuple[str, ...]], - aggregates: Dict[str, Dict[str, Any]], -) -> str: - """ - Function for flattening column names into a single string. This step is necessary - to be able to properly serialize a DataFrame. If the column is a string, return - element unchanged. For multi-element columns, join column elements with a comma, - with the exception of pivots made with a single aggregate, in which case the - aggregate column name is omitted. - - :param column: single element from `DataFrame.columns` - :param aggregates: aggregates - :return: - """ - if not isinstance(column, tuple): - column = (column,) - if len(aggregates) == 1 and len(column) > 1: - # drop aggregate for single aggregate pivots with multiple groupings - # from column name (aggregates always come first in column name) - column = column[1:] - return ", ".join([str(col) for col in column]) - - -def validate_column_args(*argnames: str) -> Callable[..., Any]: - def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: - def wrapped(df: DataFrame, **options: Any) -> Any: - if options.get("is_pivot_df"): - # skip validation when pivot Dataframe - return func(df, **options) - columns = df.columns.tolist() - for name in argnames: - if name in options and not all( - elem in columns for elem in options.get(name) or [] - ): - raise QueryObjectValidationError( - _("Referenced columns not available in DataFrame.") - ) - return func(df, **options) - - return wrapped - - return wrapper - - -def _get_aggregate_funcs( - df: DataFrame, aggregates: Dict[str, Dict[str, Any]], -) -> Dict[str, NamedAgg]: - """ - Converts a set of aggregate config objects into functions that pandas can use as - aggregators. Currently only numpy aggregators are supported. - - :param df: DataFrame on which to perform aggregate operation. - :param aggregates: Mapping from column name to aggregate config. - :return: Mapping from metric name to function that takes a single input argument. - """ - agg_funcs: Dict[str, NamedAgg] = {} - for name, agg_obj in aggregates.items(): - column = agg_obj.get("column", name) - if column not in df: - raise QueryObjectValidationError( - _( - "Column referenced by aggregate is undefined: %(column)s", - column=column, - ) - ) - if "operator" not in agg_obj: - raise QueryObjectValidationError( - _("Operator undefined for aggregator: %(name)s", name=name,) - ) - operator = agg_obj["operator"] - if callable(operator): - aggfunc = operator - else: - func = NUMPY_FUNCTIONS.get(operator) - if not func: - raise QueryObjectValidationError( - _("Invalid numpy function: %(operator)s", operator=operator,) - ) - options = agg_obj.get("options", {}) - aggfunc = partial(func, **options) - agg_funcs[name] = NamedAgg(column=column, aggfunc=aggfunc) - - return agg_funcs - - -def _append_columns( - base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str] -) -> DataFrame: - """ - Function for adding columns from one DataFrame to another DataFrame. Calls the - assign method, which overwrites the original column in `base_df` if the column - already exists, and appends the column if the name is not defined. - - :param base_df: DataFrame which to use as the base - :param append_df: DataFrame from which to select data. - :param columns: columns on which to append, mapping source column to - target column. For instance, `{'y': 'y'}` will replace the values in - column `y` in `base_df` with the values in `y` in `append_df`, - while `{'y': 'y2'}` will add a column `y2` to `base_df` based - on values in column `y` in `append_df`, leaving the original column `y` - in `base_df` unchanged. - :return: new DataFrame with combined data from `base_df` and `append_df` - """ - return base_df.assign( - **{target: append_df[source] for source, target in columns.items()} - ) - - -@validate_column_args("index", "columns") -def pivot( # pylint: disable=too-many-arguments,too-many-locals - df: DataFrame, - index: List[str], - aggregates: Dict[str, Dict[str, Any]], - columns: Optional[List[str]] = None, - metric_fill_value: Optional[Any] = None, - column_fill_value: Optional[str] = NULL_STRING, - drop_missing_columns: Optional[bool] = True, - combine_value_with_metric: bool = False, - marginal_distributions: Optional[bool] = None, - marginal_distribution_name: Optional[str] = None, - flatten_columns: bool = True, - reset_index: bool = True, -) -> DataFrame: - """ - Perform a pivot operation on a DataFrame. - - :param df: Object on which pivot operation will be performed - :param index: Columns to group by on the table index (=rows) - :param columns: Columns to group by on the table columns - :param metric_fill_value: Value to replace missing values with - :param column_fill_value: Value to replace missing pivot columns with. By default - replaces missing values with "". Set to `None` to remove columns - with missing values. - :param drop_missing_columns: Do not include columns whose entries are all missing - :param combine_value_with_metric: Display metrics side by side within each column, - as opposed to each column being displayed side by side for each metric. - :param aggregates: A mapping from aggregate column name to the the aggregate - config. - :param marginal_distributions: Add totals for row/column. Default to False - :param marginal_distribution_name: Name of row/column with marginal distribution. - Default to 'All'. - :param flatten_columns: Convert column names to strings - :param reset_index: Convert index to column - :return: A pivot table - :raises QueryObjectValidationError: If the request in incorrect - """ - if not index: - raise QueryObjectValidationError( - _("Pivot operation requires at least one index") - ) - if not aggregates: - raise QueryObjectValidationError( - _("Pivot operation must include at least one aggregate") - ) - - if columns and column_fill_value: - df[columns] = df[columns].fillna(value=column_fill_value) - - aggregate_funcs = _get_aggregate_funcs(df, aggregates) - - # TODO (villebro): Pandas 1.0.3 doesn't yet support NamedAgg in pivot_table. - # Remove once/if support is added. - aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()} - - # When dropna = False, the pivot_table function will calculate cartesian-product - # for MultiIndex. - # https://github.com/apache/superset/issues/15956 - # https://github.com/pandas-dev/pandas/issues/18030 - series_set = set() - if not drop_missing_columns and columns: - for row in df[columns].itertuples(): - for metric in aggfunc.keys(): - series_set.add(str(tuple([metric]) + tuple(row[1:]))) - - df = df.pivot_table( - values=aggfunc.keys(), - index=index, - columns=columns, - aggfunc=aggfunc, - fill_value=metric_fill_value, - dropna=drop_missing_columns, - margins=marginal_distributions, - margins_name=marginal_distribution_name, - ) - - if not drop_missing_columns and len(series_set) > 0 and not df.empty: - for col in df.columns: - series = str(col) - if series not in series_set: - df = df.drop(col, axis=PandasAxis.COLUMN) - - if combine_value_with_metric: - df = df.stack(0).unstack() - - # Make index regular column - if flatten_columns: - df.columns = [ - _flatten_column_after_pivot(col, aggregates) for col in df.columns - ] - # return index as regular column - if reset_index: - df.reset_index(level=0, inplace=True) - return df - - -@validate_column_args("groupby") -def aggregate( - df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]] -) -> DataFrame: - """ - Apply aggregations to a DataFrame. - - :param df: Object to aggregate. - :param groupby: columns to aggregate - :param aggregates: A mapping from metric column to the function used to - aggregate values. - :raises QueryObjectValidationError: If the request in incorrect - """ - aggregates = aggregates or {} - aggregate_funcs = _get_aggregate_funcs(df, aggregates) - if groupby: - df_groupby = df.groupby(by=groupby) - else: - df_groupby = df.groupby(lambda _: True) - return df_groupby.agg(**aggregate_funcs).reset_index(drop=not groupby) - - -@validate_column_args("columns") -def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: - """ - Sort a DataFrame. - - :param df: DataFrame to sort. - :param columns: columns by by which to sort. The key specifies the column name, - value specifies if sorting in ascending order. - :return: Sorted DataFrame - :raises QueryObjectValidationError: If the request in incorrect - """ - return df.sort_values(by=list(columns.keys()), ascending=list(columns.values())) - - -@validate_column_args("columns") -def rolling( # pylint: disable=too-many-arguments - df: DataFrame, - rolling_type: str, - columns: Optional[Dict[str, str]] = None, - window: Optional[int] = None, - rolling_type_options: Optional[Dict[str, Any]] = None, - center: bool = False, - win_type: Optional[str] = None, - min_periods: Optional[int] = None, - is_pivot_df: bool = False, -) -> DataFrame: - """ - Apply a rolling window on the dataset. See the Pandas docs for further details: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html - - :param df: DataFrame on which the rolling period will be based. - :param columns: columns on which to perform rolling, mapping source column to - target column. For instance, `{'y': 'y'}` will replace the column `y` with - the rolling value in `y`, while `{'y': 'y2'}` will add a column `y2` based - on rolling values calculated from `y`, leaving the original column `y` - unchanged. - :param rolling_type: Type of rolling window. Any numpy function will work. - :param window: Size of the window. - :param rolling_type_options: Optional options to pass to rolling method. Needed - for e.g. quantile operation. - :param center: Should the label be at the center of the window. - :param win_type: Type of window function. - :param min_periods: The minimum amount of periods required for a row to be included - in the result set. - :param is_pivot_df: Dataframe is pivoted or not - :return: DataFrame with the rolling columns - :raises QueryObjectValidationError: If the request in incorrect - """ - rolling_type_options = rolling_type_options or {} - columns = columns or {} - if is_pivot_df: - df_rolling = df - else: - df_rolling = df[columns.keys()] - kwargs: Dict[str, Union[str, int]] = {} - if window is None: - raise QueryObjectValidationError(_("Undefined window for rolling operation")) - if window == 0: - raise QueryObjectValidationError(_("Window must be > 0")) - - kwargs["window"] = window - if min_periods is not None: - kwargs["min_periods"] = min_periods - if center is not None: - kwargs["center"] = center - if win_type is not None: - kwargs["win_type"] = win_type - - df_rolling = df_rolling.rolling(**kwargs) - if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr( - df_rolling, rolling_type - ): - raise QueryObjectValidationError( - _("Invalid rolling_type: %(type)s", type=rolling_type) - ) - try: - df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options) - except TypeError as ex: - raise QueryObjectValidationError( - _( - "Invalid options for %(rolling_type)s: %(options)s", - rolling_type=rolling_type, - options=rolling_type_options, - ) - ) from ex - - if is_pivot_df: - agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list() - agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df} - df_rolling.columns = [ - _flatten_column_after_pivot(col, agg) for col in df_rolling.columns - ] - df_rolling.reset_index(level=0, inplace=True) - else: - df_rolling = _append_columns(df, df_rolling, columns) - - if min_periods: - df_rolling = df_rolling[min_periods:] - return df_rolling - - -@validate_column_args("columns", "drop", "rename") -def select( - df: DataFrame, - columns: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - rename: Optional[Dict[str, str]] = None, -) -> DataFrame: - """ - Only select a subset of columns in the original dataset. Can be useful for - removing unnecessary intermediate results, renaming and reordering columns. - - :param df: DataFrame on which the rolling period will be based. - :param columns: Columns which to select from the DataFrame, in the desired order. - If left undefined, all columns will be selected. If columns are - renamed, the original column name should be referenced here. - :param exclude: columns to exclude from selection. If columns are renamed, the new - column name should be referenced here. - :param rename: columns which to rename, mapping source column to target column. - For instance, `{'y': 'y2'}` will rename the column `y` to - `y2`. - :return: Subset of columns in original DataFrame - :raises QueryObjectValidationError: If the request in incorrect - """ - df_select = df.copy(deep=False) - if columns: - df_select = df_select[columns] - if exclude: - df_select = df_select.drop(exclude, axis=1) - if rename is not None: - df_select = df_select.rename(columns=rename) - return df_select - - -@validate_column_args("columns") -def diff( - df: DataFrame, - columns: Dict[str, str], - periods: int = 1, - axis: PandasAxis = PandasAxis.ROW, -) -> DataFrame: - """ - Calculate row-by-row or column-by-column difference for select columns. - - :param df: DataFrame on which the diff will be based. - :param columns: columns on which to perform diff, mapping source column to - target column. For instance, `{'y': 'y'}` will replace the column `y` with - the diff value in `y`, while `{'y': 'y2'}` will add a column `y2` based - on diff values calculated from `y`, leaving the original column `y` - unchanged. - :param periods: periods to shift for calculating difference. - :param axis: 0 for row, 1 for column. default 0. - :return: DataFrame with diffed columns - :raises QueryObjectValidationError: If the request in incorrect - """ - df_diff = df[columns.keys()] - df_diff = df_diff.diff(periods=periods, axis=axis) - return _append_columns(df, df_diff, columns) - - -@validate_column_args("source_columns", "compare_columns") -def compare( # pylint: disable=too-many-arguments - df: DataFrame, - source_columns: List[str], - compare_columns: List[str], - compare_type: Optional[PandasPostprocessingCompare], - drop_original_columns: Optional[bool] = False, - precision: Optional[int] = 4, -) -> DataFrame: - """ - Calculate column-by-column changing for select columns. - - :param df: DataFrame on which the compare will be based. - :param source_columns: Main query columns - :param compare_columns: Columns being compared - :param compare_type: Type of compare. Choice of `absolute`, `percentage` or `ratio` - :param drop_original_columns: Whether to remove the source columns and - compare columns. - :param precision: Round a change rate to a variable number of decimal places. - :return: DataFrame with compared columns. - :raises QueryObjectValidationError: If the request in incorrect. - """ - if len(source_columns) != len(compare_columns): - raise QueryObjectValidationError( - _("`compare_columns` must have the same length as `source_columns`.") - ) - if compare_type not in tuple(PandasPostprocessingCompare): - raise QueryObjectValidationError( - _("`compare_type` must be `difference`, `percentage` or `ratio`") - ) - if len(source_columns) == 0: - return df - - for s_col, c_col in zip(source_columns, compare_columns): - if compare_type == PandasPostprocessingCompare.DIFF: - diff_series = df[s_col] - df[c_col] - elif compare_type == PandasPostprocessingCompare.PCT: - diff_series = ( - ((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision) - ) - else: - # compare_type == "ratio" - diff_series = (df[s_col] / df[c_col]).astype(float).round(precision) - diff_df = diff_series.to_frame( - name=TIME_COMPARISION.join([compare_type, s_col, c_col]) - ) - df = pd.concat([df, diff_df], axis=1) - - if drop_original_columns: - df = df.drop(source_columns + compare_columns, axis=1) - return df - - -@validate_column_args("columns") -def cum( - df: DataFrame, - operator: str, - columns: Optional[Dict[str, str]] = None, - is_pivot_df: bool = False, -) -> DataFrame: - """ - Calculate cumulative sum/product/min/max for select columns. - - :param df: DataFrame on which the cumulative operation will be based. - :param columns: columns on which to perform a cumulative operation, mapping source - column to target column. For instance, `{'y': 'y'}` will replace the column - `y` with the cumulative value in `y`, while `{'y': 'y2'}` will add a column - `y2` based on cumulative values calculated from `y`, leaving the original - column `y` unchanged. - :param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max` - :param is_pivot_df: Dataframe is pivoted or not - :return: DataFrame with cumulated columns - """ - columns = columns or {} - if is_pivot_df: - df_cum = df - else: - df_cum = df[columns.keys()] - operation = "cum" + operator - if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr( - df_cum, operation - ): - raise QueryObjectValidationError( - _("Invalid cumulative operator: %(operator)s", operator=operator) - ) - if is_pivot_df: - df_cum = getattr(df_cum, operation)() - agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list() - agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df} - df_cum.columns = [ - _flatten_column_after_pivot(col, agg) for col in df_cum.columns - ] - df_cum.reset_index(level=0, inplace=True) - else: - df_cum = _append_columns(df, getattr(df_cum, operation)(), columns) - return df_cum - - -def geohash_decode( - df: DataFrame, geohash: str, longitude: str, latitude: str -) -> DataFrame: - """ - Decode a geohash column into longitude and latitude - - :param df: DataFrame containing geohash data - :param geohash: Name of source column containing geohash location. - :param longitude: Name of new column to be created containing longitude. - :param latitude: Name of new column to be created containing latitude. - :return: DataFrame with decoded longitudes and latitudes - """ - try: - lonlat_df = DataFrame() - lonlat_df["latitude"], lonlat_df["longitude"] = zip( - *df[geohash].apply(geohash_lib.decode) - ) - return _append_columns( - df, lonlat_df, {"latitude": latitude, "longitude": longitude} - ) - except ValueError as ex: - raise QueryObjectValidationError(_("Invalid geohash string")) from ex - - -def geohash_encode( - df: DataFrame, geohash: str, longitude: str, latitude: str, -) -> DataFrame: - """ - Encode longitude and latitude into geohash - - :param df: DataFrame containing longitude and latitude data - :param geohash: Name of new column to be created containing geohash location. - :param longitude: Name of source column containing longitude. - :param latitude: Name of source column containing latitude. - :return: DataFrame with decoded longitudes and latitudes - """ - try: - encode_df = df[[latitude, longitude]] - encode_df.columns = ["latitude", "longitude"] - encode_df["geohash"] = encode_df.apply( - lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1, - ) - return _append_columns(df, encode_df, {"geohash": geohash}) - except ValueError as ex: - raise QueryObjectValidationError(_("Invalid longitude/latitude")) from ex - - -def geodetic_parse( - df: DataFrame, - geodetic: str, - longitude: str, - latitude: str, - altitude: Optional[str] = None, -) -> DataFrame: - """ - Parse a column containing a geodetic point string - [Geopy](https://geopy.readthedocs.io/en/stable/#geopy.point.Point). - - :param df: DataFrame containing geodetic point data - :param geodetic: Name of source column containing geodetic point string. - :param longitude: Name of new column to be created containing longitude. - :param latitude: Name of new column to be created containing latitude. - :param altitude: Name of new column to be created containing altitude. - :return: DataFrame with decoded longitudes and latitudes - """ - - def _parse_location(location: str) -> Tuple[float, float, float]: - """ - Parse a string containing a geodetic point and return latitude, longitude - and altitude - """ - point = Point(location) - return point[0], point[1], point[2] - - try: - geodetic_df = DataFrame() - ( - geodetic_df["latitude"], - geodetic_df["longitude"], - geodetic_df["altitude"], - ) = zip(*df[geodetic].apply(_parse_location)) - columns = {"latitude": latitude, "longitude": longitude} - if altitude: - columns["altitude"] = altitude - return _append_columns(df, geodetic_df, columns) - except ValueError as ex: - raise QueryObjectValidationError(_("Invalid geodetic string")) from ex - - -@validate_column_args("columns") -def contribution( - df: DataFrame, - orientation: Optional[ - PostProcessingContributionOrientation - ] = PostProcessingContributionOrientation.COLUMN, - columns: Optional[List[str]] = None, - rename_columns: Optional[List[str]] = None, -) -> DataFrame: - """ - Calculate cell contibution to row/column total for numeric columns. - Non-numeric columns will be kept untouched. - - If `columns` are specified, only calculate contributions on selected columns. - - :param df: DataFrame containing all-numeric data (temporal column ignored) - :param columns: Columns to calculate values from. - :param rename_columns: The new labels for the calculated contribution columns. - The original columns will not be removed. - :param orientation: calculate by dividing cell with row/column total - :return: DataFrame with contributions. - """ - contribution_df = df.copy() - numeric_df = contribution_df.select_dtypes(include=["number", Decimal]) - # verify column selections - if columns: - numeric_columns = numeric_df.columns.tolist() - for col in columns: - if col not in numeric_columns: - raise QueryObjectValidationError( - _( - 'Column "%(column)s" is not numeric or does not ' - "exists in the query results.", - column=col, - ) - ) - columns = columns or numeric_df.columns - rename_columns = rename_columns or columns - if len(rename_columns) != len(columns): - raise QueryObjectValidationError( - _("`rename_columns` must have the same length as `columns`.") - ) - # limit to selected columns - numeric_df = numeric_df[columns] - axis = 0 if orientation == PostProcessingContributionOrientation.COLUMN else 1 - numeric_df = numeric_df / numeric_df.values.sum(axis=axis, keepdims=True) - contribution_df[rename_columns] = numeric_df - return contribution_df - - -def _prophet_parse_seasonality( - input_value: Optional[Union[bool, int]] -) -> Union[bool, str, int]: - if input_value is None: - return "auto" - if isinstance(input_value, bool): - return input_value - try: - return int(input_value) - except ValueError: - return input_value - - -def _prophet_fit_and_predict( # pylint: disable=too-many-arguments - df: DataFrame, - confidence_interval: float, - yearly_seasonality: Union[bool, str, int], - weekly_seasonality: Union[bool, str, int], - daily_seasonality: Union[bool, str, int], - periods: int, - freq: str, -) -> DataFrame: - """ - Fit a prophet model and return a DataFrame with predicted results. - """ - try: - # pylint: disable=import-error,import-outside-toplevel - from prophet import Prophet - - prophet_logger = logging.getLogger("prophet.plot") - prophet_logger.setLevel(logging.CRITICAL) - prophet_logger.setLevel(logging.NOTSET) - except ModuleNotFoundError as ex: - raise QueryObjectValidationError(_("`prophet` package not installed")) from ex - model = Prophet( - interval_width=confidence_interval, - yearly_seasonality=yearly_seasonality, - weekly_seasonality=weekly_seasonality, - daily_seasonality=daily_seasonality, - ) - if df["ds"].dt.tz: - df["ds"] = df["ds"].dt.tz_convert(None) - model.fit(df) - future = model.make_future_dataframe(periods=periods, freq=freq) - forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]] - return forecast.join(df.set_index("ds"), on="ds").set_index(["ds"]) - - -def prophet( # pylint: disable=too-many-arguments - df: DataFrame, - time_grain: str, - periods: int, - confidence_interval: float, - yearly_seasonality: Optional[Union[bool, int]] = None, - weekly_seasonality: Optional[Union[bool, int]] = None, - daily_seasonality: Optional[Union[bool, int]] = None, - index: Optional[str] = None, -) -> DataFrame: - """ - Add forecasts to each series in a timeseries dataframe, along with confidence - intervals for the prediction. For each series, the operation creates three - new columns with the column name suffixed with the following values: - - - `__yhat`: the forecast for the given date - - `__yhat_lower`: the lower bound of the forecast for the given date - - `__yhat_upper`: the upper bound of the forecast for the given date - - - :param df: DataFrame containing all-numeric data (temporal column ignored) - :param time_grain: Time grain used to specify time period increments in prediction - :param periods: Time periods (in units of `time_grain`) to predict into the future - :param confidence_interval: Width of predicted confidence interval - :param yearly_seasonality: Should yearly seasonality be applied. - An integer value will specify Fourier order of seasonality. - :param weekly_seasonality: Should weekly seasonality be applied. - An integer value will specify Fourier order of seasonality, `None` will - automatically detect seasonality. - :param daily_seasonality: Should daily seasonality be applied. - An integer value will specify Fourier order of seasonality, `None` will - automatically detect seasonality. - :param index: the name of the column containing the x-axis data - :return: DataFrame with contributions, with temporal column at beginning if present - """ - index = index or DTTM_ALIAS - # validate inputs - if not time_grain: - raise QueryObjectValidationError(_("Time grain missing")) - if time_grain not in PROPHET_TIME_GRAIN_MAP: - raise QueryObjectValidationError( - _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,) - ) - freq = PROPHET_TIME_GRAIN_MAP[time_grain] - # check type at runtime due to marhsmallow schema not being able to handle - # union types - if not isinstance(periods, int) or periods < 0: - raise QueryObjectValidationError(_("Periods must be a whole number")) - if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1: - raise QueryObjectValidationError( - _("Confidence interval must be between 0 and 1 (exclusive)") - ) - if index not in df.columns: - raise QueryObjectValidationError(_("DataFrame must include temporal column")) - if len(df.columns) < 2: - raise QueryObjectValidationError(_("DataFrame include at least one series")) - - target_df = DataFrame() - for column in [column for column in df.columns if column != index]: - fit_df = _prophet_fit_and_predict( - df=df[[index, column]].rename(columns={index: "ds", column: "y"}), - confidence_interval=confidence_interval, - yearly_seasonality=_prophet_parse_seasonality(yearly_seasonality), - weekly_seasonality=_prophet_parse_seasonality(weekly_seasonality), - daily_seasonality=_prophet_parse_seasonality(daily_seasonality), - periods=periods, - freq=freq, - ) - new_columns = [ - f"{column}__yhat", - f"{column}__yhat_lower", - f"{column}__yhat_upper", - f"{column}", - ] - fit_df.columns = new_columns - if target_df.empty: - target_df = fit_df - else: - for new_column in new_columns: - target_df = target_df.assign(**{new_column: fit_df[new_column]}) - target_df.reset_index(level=0, inplace=True) - return target_df.rename(columns={"ds": index}) - - -def boxplot( - df: DataFrame, - groupby: List[str], - metrics: List[str], - whisker_type: PostProcessingBoxplotWhiskerType, - percentiles: Optional[ - Union[List[Union[int, float]], Tuple[Union[int, float], Union[int, float]]] - ] = None, -) -> DataFrame: - """ - Calculate boxplot statistics. For each metric, the operation creates eight - new columns with the column name suffixed with the following values: - - - `__mean`: the mean - - `__median`: the median - - `__max`: the maximum value excluding outliers (see whisker type) - - `__min`: the minimum value excluding outliers (see whisker type) - - `__q1`: the median - - `__q1`: the first quartile (25th percentile) - - `__q3`: the third quartile (75th percentile) - - `__count`: count of observations - - `__outliers`: the values that fall outside the minimum/maximum value - (see whisker type) - - :param df: DataFrame containing all-numeric data (temporal column ignored) - :param groupby: The categories to group by (x-axis) - :param metrics: The metrics for which to calculate the distribution - :param whisker_type: The confidence level type - :return: DataFrame with boxplot statistics per groupby - """ - - def quartile1(series: Series) -> float: - return np.nanpercentile(series, 25, interpolation="midpoint") - - def quartile3(series: Series) -> float: - return np.nanpercentile(series, 75, interpolation="midpoint") - - if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY: - - def whisker_high(series: Series) -> float: - upper_outer_lim = quartile3(series) + 1.5 * ( - quartile3(series) - quartile1(series) - ) - return series[series <= upper_outer_lim].max() - - def whisker_low(series: Series) -> float: - lower_outer_lim = quartile1(series) - 1.5 * ( - quartile3(series) - quartile1(series) - ) - return series[series >= lower_outer_lim].min() - - elif whisker_type == PostProcessingBoxplotWhiskerType.PERCENTILE: - if ( - not isinstance(percentiles, (list, tuple)) - or len(percentiles) != 2 - or not isinstance(percentiles[0], (int, float)) - or not isinstance(percentiles[1], (int, float)) - or percentiles[0] >= percentiles[1] - ): - raise QueryObjectValidationError( - _( - "percentiles must be a list or tuple with two numeric values, " - "of which the first is lower than the second value" - ) - ) - low, high = percentiles[0], percentiles[1] - - def whisker_high(series: Series) -> float: - return np.nanpercentile(series, high) - - def whisker_low(series: Series) -> float: - return np.nanpercentile(series, low) - - else: - whisker_high = np.max - whisker_low = np.min - - def outliers(series: Series) -> Set[float]: - above = series[series > whisker_high(series)] - below = series[series < whisker_low(series)] - return above.tolist() + below.tolist() - - operators: Dict[str, Callable[[Any], Any]] = { - "mean": np.mean, - "median": np.median, - "max": whisker_high, - "min": whisker_low, - "q1": quartile1, - "q3": quartile3, - "count": np.ma.count, - "outliers": outliers, - } - aggregates: Dict[str, Dict[str, Union[str, Callable[..., Any]]]] = { - f"{metric}__{operator_name}": {"column": metric, "operator": operator} - for operator_name, operator in operators.items() - for metric in metrics - } - return aggregate(df, groupby=groupby, aggregates=aggregates) - - -@validate_column_args("groupby_columns") -def resample( # pylint: disable=too-many-arguments - df: DataFrame, - rule: str, - method: str, - time_column: str, - groupby_columns: Optional[Tuple[Optional[str], ...]] = None, - fill_value: Optional[Union[float, int]] = None, -) -> DataFrame: - """ - support upsampling in resample - - :param df: DataFrame to resample. - :param rule: The offset string representing target conversion. - :param method: How to fill the NaN value after resample. - :param time_column: existing columns in DataFrame. - :param groupby_columns: columns except time_column in dataframe - :param fill_value: What values do fill missing. - :return: DataFrame after resample - :raises QueryObjectValidationError: If the request in incorrect - """ - - def _upsampling(_df: DataFrame) -> DataFrame: - _df = _df.set_index(time_column) - if method == "asfreq" and fill_value is not None: - return _df.resample(rule).asfreq(fill_value=fill_value) - return getattr(_df.resample(rule), method)() - - if groupby_columns: - df = ( - df.set_index(keys=list(groupby_columns)) - .groupby(by=list(groupby_columns)) - .apply(_upsampling) - ) - df = df.reset_index().set_index(time_column).sort_index() - else: - df = _upsampling(df) - return df.reset_index() diff --git a/superset/utils/pandas_postprocessing/__init__.py b/superset/utils/pandas_postprocessing/__init__.py new file mode 100644 index 0000000000000..245692337bc3f --- /dev/null +++ b/superset/utils/pandas_postprocessing/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + diff --git a/superset/utils/pandas_postprocessing/aggregate.py b/superset/utils/pandas_postprocessing/aggregate.py new file mode 100644 index 0000000000000..5dde70f998e2b --- /dev/null +++ b/superset/utils/pandas_postprocessing/aggregate.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List +from pandas import DataFrame + +from superset.utils.pandas_postprocessing.utils import ( + validate_column_args, + _get_aggregate_funcs, +) + + +@validate_column_args("groupby") +def aggregate( + df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]] +) -> DataFrame: + """ + Apply aggregations to a DataFrame. + + :param df: Object to aggregate. + :param groupby: columns to aggregate + :param aggregates: A mapping from metric column to the function used to + aggregate values. + :raises QueryObjectValidationError: If the request in incorrect + """ + aggregates = aggregates or {} + aggregate_funcs = _get_aggregate_funcs(df, aggregates) + if groupby: + df_groupby = df.groupby(by=groupby) + else: + df_groupby = df.groupby(lambda _: True) + return df_groupby.agg(**aggregate_funcs).reset_index(drop=not groupby) + + diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py new file mode 100644 index 0000000000000..ce00ed479c762 --- /dev/null +++ b/superset/utils/pandas_postprocessing/boxplot.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import numpy as np +from flask_babel import gettext as _ +from pandas import DataFrame, Series + +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import ( + PostProcessingBoxplotWhiskerType, +) +from superset.utils.pandas_postprocessing.aggregate import aggregate + + +def boxplot( + df: DataFrame, + groupby: List[str], + metrics: List[str], + whisker_type: PostProcessingBoxplotWhiskerType, + percentiles: Optional[ + Union[List[Union[int, float]], Tuple[Union[int, float], Union[int, float]]] + ] = None, +) -> DataFrame: + """ + Calculate boxplot statistics. For each metric, the operation creates eight + new columns with the column name suffixed with the following values: + + - `__mean`: the mean + - `__median`: the median + - `__max`: the maximum value excluding outliers (see whisker type) + - `__min`: the minimum value excluding outliers (see whisker type) + - `__q1`: the median + - `__q1`: the first quartile (25th percentile) + - `__q3`: the third quartile (75th percentile) + - `__count`: count of observations + - `__outliers`: the values that fall outside the minimum/maximum value + (see whisker type) + + :param df: DataFrame containing all-numeric data (temporal column ignored) + :param groupby: The categories to group by (x-axis) + :param metrics: The metrics for which to calculate the distribution + :param whisker_type: The confidence level type + :return: DataFrame with boxplot statistics per groupby + """ + + def quartile1(series: Series) -> float: + return np.nanpercentile(series, 25, interpolation="midpoint") + + def quartile3(series: Series) -> float: + return np.nanpercentile(series, 75, interpolation="midpoint") + + if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY: + + def whisker_high(series: Series) -> float: + upper_outer_lim = quartile3(series) + 1.5 * ( + quartile3(series) - quartile1(series) + ) + return series[series <= upper_outer_lim].max() + + def whisker_low(series: Series) -> float: + lower_outer_lim = quartile1(series) - 1.5 * ( + quartile3(series) - quartile1(series) + ) + return series[series >= lower_outer_lim].min() + + elif whisker_type == PostProcessingBoxplotWhiskerType.PERCENTILE: + if ( + not isinstance(percentiles, (list, tuple)) + or len(percentiles) != 2 + or not isinstance(percentiles[0], (int, float)) + or not isinstance(percentiles[1], (int, float)) + or percentiles[0] >= percentiles[1] + ): + raise QueryObjectValidationError( + _( + "percentiles must be a list or tuple with two numeric values, " + "of which the first is lower than the second value" + ) + ) + low, high = percentiles[0], percentiles[1] + + def whisker_high(series: Series) -> float: + return np.nanpercentile(series, high) + + def whisker_low(series: Series) -> float: + return np.nanpercentile(series, low) + + else: + whisker_high = np.max + whisker_low = np.min + + def outliers(series: Series) -> Set[float]: + above = series[series > whisker_high(series)] + below = series[series < whisker_low(series)] + return above.tolist() + below.tolist() + + operators: Dict[str, Callable[[Any], Any]] = { + "mean": np.mean, + "median": np.median, + "max": whisker_high, + "min": whisker_low, + "q1": quartile1, + "q3": quartile3, + "count": np.ma.count, + "outliers": outliers, + } + aggregates: Dict[str, Dict[str, Union[str, Callable[..., Any]]]] = { + f"{metric}__{operator_name}": {"column": metric, "operator": operator} + for operator_name, operator in operators.items() + for metric in metrics + } + return aggregate(df, groupby=groupby, aggregates=aggregates) diff --git a/superset/utils/pandas_postprocessing/compare.py b/superset/utils/pandas_postprocessing/compare.py new file mode 100644 index 0000000000000..7c10854a4fb86 --- /dev/null +++ b/superset/utils/pandas_postprocessing/compare.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional + +import pandas as pd +from flask_babel import gettext as _ +from pandas import DataFrame + +from superset.constants import PandasPostprocessingCompare +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import ( + TIME_COMPARISION, +) +from superset.utils.pandas_postprocessing.utils import validate_column_args + + +@validate_column_args("source_columns", "compare_columns") +def compare( # pylint: disable=too-many-arguments + df: DataFrame, + source_columns: List[str], + compare_columns: List[str], + compare_type: Optional[PandasPostprocessingCompare], + drop_original_columns: Optional[bool] = False, + precision: Optional[int] = 4, +) -> DataFrame: + """ + Calculate column-by-column changing for select columns. + + :param df: DataFrame on which the compare will be based. + :param source_columns: Main query columns + :param compare_columns: Columns being compared + :param compare_type: Type of compare. Choice of `absolute`, `percentage` or `ratio` + :param drop_original_columns: Whether to remove the source columns and + compare columns. + :param precision: Round a change rate to a variable number of decimal places. + :return: DataFrame with compared columns. + :raises QueryObjectValidationError: If the request in incorrect. + """ + if len(source_columns) != len(compare_columns): + raise QueryObjectValidationError( + _("`compare_columns` must have the same length as `source_columns`.") + ) + if compare_type not in tuple(PandasPostprocessingCompare): + raise QueryObjectValidationError( + _("`compare_type` must be `difference`, `percentage` or `ratio`") + ) + if len(source_columns) == 0: + return df + + for s_col, c_col in zip(source_columns, compare_columns): + if compare_type == PandasPostprocessingCompare.DIFF: + diff_series = df[s_col] - df[c_col] + elif compare_type == PandasPostprocessingCompare.PCT: + diff_series = ( + ((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision) + ) + else: + # compare_type == "ratio" + diff_series = (df[s_col] / df[c_col]).astype(float).round(precision) + diff_df = diff_series.to_frame( + name=TIME_COMPARISION.join([compare_type, s_col, c_col]) + ) + df = pd.concat([df, diff_df], axis=1) + + if drop_original_columns: + df = df.drop(source_columns + compare_columns, axis=1) + return df diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py new file mode 100644 index 0000000000000..1b06b669806f9 --- /dev/null +++ b/superset/utils/pandas_postprocessing/contribution.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from decimal import Decimal +from typing import List, Optional + +from flask_babel import gettext as _ +from pandas import DataFrame + +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import ( + PostProcessingContributionOrientation, +) +from superset.utils.pandas_postprocessing.utils import validate_column_args + + +@validate_column_args("columns") +def contribution( + df: DataFrame, + orientation: Optional[ + PostProcessingContributionOrientation + ] = PostProcessingContributionOrientation.COLUMN, + columns: Optional[List[str]] = None, + rename_columns: Optional[List[str]] = None, +) -> DataFrame: + """ + Calculate cell contibution to row/column total for numeric columns. + Non-numeric columns will be kept untouched. + + If `columns` are specified, only calculate contributions on selected columns. + + :param df: DataFrame containing all-numeric data (temporal column ignored) + :param columns: Columns to calculate values from. + :param rename_columns: The new labels for the calculated contribution columns. + The original columns will not be removed. + :param orientation: calculate by dividing cell with row/column total + :return: DataFrame with contributions. + """ + contribution_df = df.copy() + numeric_df = contribution_df.select_dtypes(include=["number", Decimal]) + # verify column selections + if columns: + numeric_columns = numeric_df.columns.tolist() + for col in columns: + if col not in numeric_columns: + raise QueryObjectValidationError( + _( + 'Column "%(column)s" is not numeric or does not ' + "exists in the query results.", + column=col, + ) + ) + columns = columns or numeric_df.columns + rename_columns = rename_columns or columns + if len(rename_columns) != len(columns): + raise QueryObjectValidationError( + _("`rename_columns` must have the same length as `columns`.") + ) + # limit to selected columns + numeric_df = numeric_df[columns] + axis = 0 if orientation == PostProcessingContributionOrientation.COLUMN else 1 + numeric_df = numeric_df / numeric_df.values.sum(axis=axis, keepdims=True) + contribution_df[rename_columns] = numeric_df + return contribution_df diff --git a/superset/utils/pandas_postprocessing/cum.py b/superset/utils/pandas_postprocessing/cum.py new file mode 100644 index 0000000000000..9632a42f90afb --- /dev/null +++ b/superset/utils/pandas_postprocessing/cum.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Optional + +from flask_babel import gettext as _ +from pandas import DataFrame + +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing.utils import validate_column_args + + +@validate_column_args("columns") +def cum( + df: DataFrame, + operator: str, + columns: Optional[Dict[str, str]] = None, + is_pivot_df: bool = False, +) -> DataFrame: + """ + Calculate cumulative sum/product/min/max for select columns. + + :param df: DataFrame on which the cumulative operation will be based. + :param columns: columns on which to perform a cumulative operation, mapping source + column to target column. For instance, `{'y': 'y'}` will replace the column + `y` with the cumulative value in `y`, while `{'y': 'y2'}` will add a column + `y2` based on cumulative values calculated from `y`, leaving the original + column `y` unchanged. + :param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max` + :param is_pivot_df: Dataframe is pivoted or not + :return: DataFrame with cumulated columns + """ + columns = columns or {} + if is_pivot_df: + df_cum = df + else: + df_cum = df[columns.keys()] + operation = "cum" + operator + if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr( + df_cum, operation + ): + raise QueryObjectValidationError( + _("Invalid cumulative operator: %(operator)s", operator=operator) + ) + if is_pivot_df: + df_cum = getattr(df_cum, operation)() + agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list() + agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df} + df_cum.columns = [ + _flatten_column_after_pivot(col, agg) for col in df_cum.columns + ] + df_cum.reset_index(level=0, inplace=True) + else: + df_cum = _append_columns(df, getattr(df_cum, operation)(), columns) + return df_cum diff --git a/superset/utils/pandas_postprocessing/diff.py b/superset/utils/pandas_postprocessing/diff.py new file mode 100644 index 0000000000000..866bb59270db1 --- /dev/null +++ b/superset/utils/pandas_postprocessing/diff.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@validate_column_args("columns") +def diff( + df: DataFrame, + columns: Dict[str, str], + periods: int = 1, + axis: PandasAxis = PandasAxis.ROW, +) -> DataFrame: + """ + Calculate row-by-row or column-by-column difference for select columns. + + :param df: DataFrame on which the diff will be based. + :param columns: columns on which to perform diff, mapping source column to + target column. For instance, `{'y': 'y'}` will replace the column `y` with + the diff value in `y`, while `{'y': 'y2'}` will add a column `y2` based + on diff values calculated from `y`, leaving the original column `y` + unchanged. + :param periods: periods to shift for calculating difference. + :param axis: 0 for row, 1 for column. default 0. + :return: DataFrame with diffed columns + :raises QueryObjectValidationError: If the request in incorrect + """ + df_diff = df[columns.keys()] + df_diff = df_diff.diff(periods=periods, axis=axis) + return _append_columns(df, df_diff, columns) diff --git a/superset/utils/pandas_postprocessing/geography.py b/superset/utils/pandas_postprocessing/geography.py new file mode 100644 index 0000000000000..5373836990272 --- /dev/null +++ b/superset/utils/pandas_postprocessing/geography.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +def geohash_decode( + df: DataFrame, geohash: str, longitude: str, latitude: str +) -> DataFrame: + """ + Decode a geohash column into longitude and latitude + + :param df: DataFrame containing geohash data + :param geohash: Name of source column containing geohash location. + :param longitude: Name of new column to be created containing longitude. + :param latitude: Name of new column to be created containing latitude. + :return: DataFrame with decoded longitudes and latitudes + """ + try: + lonlat_df = DataFrame() + lonlat_df["latitude"], lonlat_df["longitude"] = zip( + *df[geohash].apply(geohash_lib.decode) + ) + return _append_columns( + df, lonlat_df, {"latitude": latitude, "longitude": longitude} + ) + except ValueError as ex: + raise QueryObjectValidationError(_("Invalid geohash string")) from ex + + +def geohash_encode( + df: DataFrame, geohash: str, longitude: str, latitude: str, +) -> DataFrame: + """ + Encode longitude and latitude into geohash + + :param df: DataFrame containing longitude and latitude data + :param geohash: Name of new column to be created containing geohash location. + :param longitude: Name of source column containing longitude. + :param latitude: Name of source column containing latitude. + :return: DataFrame with decoded longitudes and latitudes + """ + try: + encode_df = df[[latitude, longitude]] + encode_df.columns = ["latitude", "longitude"] + encode_df["geohash"] = encode_df.apply( + lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1, + ) + return _append_columns(df, encode_df, {"geohash": geohash}) + except ValueError as ex: + raise QueryObjectValidationError(_("Invalid longitude/latitude")) from ex + + +def geodetic_parse( + df: DataFrame, + geodetic: str, + longitude: str, + latitude: str, + altitude: Optional[str] = None, +) -> DataFrame: + """ + Parse a column containing a geodetic point string + [Geopy](https://geopy.readthedocs.io/en/stable/#geopy.point.Point). + + :param df: DataFrame containing geodetic point data + :param geodetic: Name of source column containing geodetic point string. + :param longitude: Name of new column to be created containing longitude. + :param latitude: Name of new column to be created containing latitude. + :param altitude: Name of new column to be created containing altitude. + :return: DataFrame with decoded longitudes and latitudes + """ + + def _parse_location(location: str) -> Tuple[float, float, float]: + """ + Parse a string containing a geodetic point and return latitude, longitude + and altitude + """ + point = Point(location) + return point[0], point[1], point[2] + + try: + geodetic_df = DataFrame() + ( + geodetic_df["latitude"], + geodetic_df["longitude"], + geodetic_df["altitude"], + ) = zip(*df[geodetic].apply(_parse_location)) + columns = {"latitude": latitude, "longitude": longitude} + if altitude: + columns["altitude"] = altitude + return _append_columns(df, geodetic_df, columns) + except ValueError as ex: + raise QueryObjectValidationError(_("Invalid geodetic string")) from ex diff --git a/superset/utils/pandas_postprocessing/pivot.py b/superset/utils/pandas_postprocessing/pivot.py new file mode 100644 index 0000000000000..f14f98812510c --- /dev/null +++ b/superset/utils/pandas_postprocessing/pivot.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@validate_column_args("index", "columns") +def pivot( # pylint: disable=too-many-arguments,too-many-locals + df: DataFrame, + index: List[str], + aggregates: Dict[str, Dict[str, Any]], + columns: Optional[List[str]] = None, + metric_fill_value: Optional[Any] = None, + column_fill_value: Optional[str] = NULL_STRING, + drop_missing_columns: Optional[bool] = True, + combine_value_with_metric: bool = False, + marginal_distributions: Optional[bool] = None, + marginal_distribution_name: Optional[str] = None, + flatten_columns: bool = True, + reset_index: bool = True, +) -> DataFrame: + """ + Perform a pivot operation on a DataFrame. + + :param df: Object on which pivot operation will be performed + :param index: Columns to group by on the table index (=rows) + :param columns: Columns to group by on the table columns + :param metric_fill_value: Value to replace missing values with + :param column_fill_value: Value to replace missing pivot columns with. By default + replaces missing values with "". Set to `None` to remove columns + with missing values. + :param drop_missing_columns: Do not include columns whose entries are all missing + :param combine_value_with_metric: Display metrics side by side within each column, + as opposed to each column being displayed side by side for each metric. + :param aggregates: A mapping from aggregate column name to the the aggregate + config. + :param marginal_distributions: Add totals for row/column. Default to False + :param marginal_distribution_name: Name of row/column with marginal distribution. + Default to 'All'. + :param flatten_columns: Convert column names to strings + :param reset_index: Convert index to column + :return: A pivot table + :raises QueryObjectValidationError: If the request in incorrect + """ + if not index: + raise QueryObjectValidationError( + _("Pivot operation requires at least one index") + ) + if not aggregates: + raise QueryObjectValidationError( + _("Pivot operation must include at least one aggregate") + ) + + if columns and column_fill_value: + df[columns] = df[columns].fillna(value=column_fill_value) + + aggregate_funcs = _get_aggregate_funcs(df, aggregates) + + # TODO (villebro): Pandas 1.0.3 doesn't yet support NamedAgg in pivot_table. + # Remove once/if support is added. + aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()} + + # When dropna = False, the pivot_table function will calculate cartesian-product + # for MultiIndex. + # https://github.com/apache/superset/issues/15956 + # https://github.com/pandas-dev/pandas/issues/18030 + series_set = set() + if not drop_missing_columns and columns: + for row in df[columns].itertuples(): + for metric in aggfunc.keys(): + series_set.add(str(tuple([metric]) + tuple(row[1:]))) + + df = df.pivot_table( + values=aggfunc.keys(), + index=index, + columns=columns, + aggfunc=aggfunc, + fill_value=metric_fill_value, + dropna=drop_missing_columns, + margins=marginal_distributions, + margins_name=marginal_distribution_name, + ) + + if not drop_missing_columns and len(series_set) > 0 and not df.empty: + for col in df.columns: + series = str(col) + if series not in series_set: + df = df.drop(col, axis=PandasAxis.COLUMN) + + if combine_value_with_metric: + df = df.stack(0).unstack() + + # Make index regular column + if flatten_columns: + df.columns = [ + _flatten_column_after_pivot(col, aggregates) for col in df.columns + ] + # return index as regular column + if reset_index: + df.reset_index(level=0, inplace=True) + return df diff --git a/superset/utils/pandas_postprocessing/prophet.py b/superset/utils/pandas_postprocessing/prophet.py new file mode 100644 index 0000000000000..e2d9dc0f66ee7 --- /dev/null +++ b/superset/utils/pandas_postprocessing/prophet.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +def _prophet_parse_seasonality( + input_value: Optional[Union[bool, int]] +) -> Union[bool, str, int]: + if input_value is None: + return "auto" + if isinstance(input_value, bool): + return input_value + try: + return int(input_value) + except ValueError: + return input_value + + +def _prophet_fit_and_predict( # pylint: disable=too-many-arguments + df: DataFrame, + confidence_interval: float, + yearly_seasonality: Union[bool, str, int], + weekly_seasonality: Union[bool, str, int], + daily_seasonality: Union[bool, str, int], + periods: int, + freq: str, +) -> DataFrame: + """ + Fit a prophet model and return a DataFrame with predicted results. + """ + try: + # pylint: disable=import-error,import-outside-toplevel + from prophet import Prophet + + prophet_logger = logging.getLogger("prophet.plot") + prophet_logger.setLevel(logging.CRITICAL) + prophet_logger.setLevel(logging.NOTSET) + except ModuleNotFoundError as ex: + raise QueryObjectValidationError(_("`prophet` package not installed")) from ex + model = Prophet( + interval_width=confidence_interval, + yearly_seasonality=yearly_seasonality, + weekly_seasonality=weekly_seasonality, + daily_seasonality=daily_seasonality, + ) + if df["ds"].dt.tz: + df["ds"] = df["ds"].dt.tz_convert(None) + model.fit(df) + future = model.make_future_dataframe(periods=periods, freq=freq) + forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]] + return forecast.join(df.set_index("ds"), on="ds").set_index(["ds"]) + + +def prophet( # pylint: disable=too-many-arguments + df: DataFrame, + time_grain: str, + periods: int, + confidence_interval: float, + yearly_seasonality: Optional[Union[bool, int]] = None, + weekly_seasonality: Optional[Union[bool, int]] = None, + daily_seasonality: Optional[Union[bool, int]] = None, + index: Optional[str] = None, +) -> DataFrame: + """ + Add forecasts to each series in a timeseries dataframe, along with confidence + intervals for the prediction. For each series, the operation creates three + new columns with the column name suffixed with the following values: + + - `__yhat`: the forecast for the given date + - `__yhat_lower`: the lower bound of the forecast for the given date + - `__yhat_upper`: the upper bound of the forecast for the given date + + + :param df: DataFrame containing all-numeric data (temporal column ignored) + :param time_grain: Time grain used to specify time period increments in prediction + :param periods: Time periods (in units of `time_grain`) to predict into the future + :param confidence_interval: Width of predicted confidence interval + :param yearly_seasonality: Should yearly seasonality be applied. + An integer value will specify Fourier order of seasonality. + :param weekly_seasonality: Should weekly seasonality be applied. + An integer value will specify Fourier order of seasonality, `None` will + automatically detect seasonality. + :param daily_seasonality: Should daily seasonality be applied. + An integer value will specify Fourier order of seasonality, `None` will + automatically detect seasonality. + :param index: the name of the column containing the x-axis data + :return: DataFrame with contributions, with temporal column at beginning if present + """ + index = index or DTTM_ALIAS + # validate inputs + if not time_grain: + raise QueryObjectValidationError(_("Time grain missing")) + if time_grain not in PROPHET_TIME_GRAIN_MAP: + raise QueryObjectValidationError( + _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,) + ) + freq = PROPHET_TIME_GRAIN_MAP[time_grain] + # check type at runtime due to marhsmallow schema not being able to handle + # union types + if not isinstance(periods, int) or periods < 0: + raise QueryObjectValidationError(_("Periods must be a whole number")) + if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1: + raise QueryObjectValidationError( + _("Confidence interval must be between 0 and 1 (exclusive)") + ) + if index not in df.columns: + raise QueryObjectValidationError(_("DataFrame must include temporal column")) + if len(df.columns) < 2: + raise QueryObjectValidationError(_("DataFrame include at least one series")) + + target_df = DataFrame() + for column in [column for column in df.columns if column != index]: + fit_df = _prophet_fit_and_predict( + df=df[[index, column]].rename(columns={index: "ds", column: "y"}), + confidence_interval=confidence_interval, + yearly_seasonality=_prophet_parse_seasonality(yearly_seasonality), + weekly_seasonality=_prophet_parse_seasonality(weekly_seasonality), + daily_seasonality=_prophet_parse_seasonality(daily_seasonality), + periods=periods, + freq=freq, + ) + new_columns = [ + f"{column}__yhat", + f"{column}__yhat_lower", + f"{column}__yhat_upper", + f"{column}", + ] + fit_df.columns = new_columns + if target_df.empty: + target_df = fit_df + else: + for new_column in new_columns: + target_df = target_df.assign(**{new_column: fit_df[new_column]}) + target_df.reset_index(level=0, inplace=True) + return target_df.rename(columns={"ds": index}) diff --git a/superset/utils/pandas_postprocessing/resample.py b/superset/utils/pandas_postprocessing/resample.py new file mode 100644 index 0000000000000..6e2fa1612b9d7 --- /dev/null +++ b/superset/utils/pandas_postprocessing/resample.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@validate_column_args("groupby_columns") +def resample( # pylint: disable=too-many-arguments + df: DataFrame, + rule: str, + method: str, + time_column: str, + groupby_columns: Optional[Tuple[Optional[str], ...]] = None, + fill_value: Optional[Union[float, int]] = None, +) -> DataFrame: + """ + support upsampling in resample + + :param df: DataFrame to resample. + :param rule: The offset string representing target conversion. + :param method: How to fill the NaN value after resample. + :param time_column: existing columns in DataFrame. + :param groupby_columns: columns except time_column in dataframe + :param fill_value: What values do fill missing. + :return: DataFrame after resample + :raises QueryObjectValidationError: If the request in incorrect + """ + + def _upsampling(_df: DataFrame) -> DataFrame: + _df = _df.set_index(time_column) + if method == "asfreq" and fill_value is not None: + return _df.resample(rule).asfreq(fill_value=fill_value) + return getattr(_df.resample(rule), method)() + + if groupby_columns: + df = ( + df.set_index(keys=list(groupby_columns)) + .groupby(by=list(groupby_columns)) + .apply(_upsampling) + ) + df = df.reset_index().set_index(time_column).sort_index() + else: + df = _upsampling(df) + return df.reset_index() diff --git a/superset/utils/pandas_postprocessing/rolling.py b/superset/utils/pandas_postprocessing/rolling.py new file mode 100644 index 0000000000000..170a6c352d0b9 --- /dev/null +++ b/superset/utils/pandas_postprocessing/rolling.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@validate_column_args("columns") +def rolling( # pylint: disable=too-many-arguments + df: DataFrame, + rolling_type: str, + columns: Optional[Dict[str, str]] = None, + window: Optional[int] = None, + rolling_type_options: Optional[Dict[str, Any]] = None, + center: bool = False, + win_type: Optional[str] = None, + min_periods: Optional[int] = None, + is_pivot_df: bool = False, +) -> DataFrame: + """ + Apply a rolling window on the dataset. See the Pandas docs for further details: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html + + :param df: DataFrame on which the rolling period will be based. + :param columns: columns on which to perform rolling, mapping source column to + target column. For instance, `{'y': 'y'}` will replace the column `y` with + the rolling value in `y`, while `{'y': 'y2'}` will add a column `y2` based + on rolling values calculated from `y`, leaving the original column `y` + unchanged. + :param rolling_type: Type of rolling window. Any numpy function will work. + :param window: Size of the window. + :param rolling_type_options: Optional options to pass to rolling method. Needed + for e.g. quantile operation. + :param center: Should the label be at the center of the window. + :param win_type: Type of window function. + :param min_periods: The minimum amount of periods required for a row to be included + in the result set. + :param is_pivot_df: Dataframe is pivoted or not + :return: DataFrame with the rolling columns + :raises QueryObjectValidationError: If the request in incorrect + """ + rolling_type_options = rolling_type_options or {} + columns = columns or {} + if is_pivot_df: + df_rolling = df + else: + df_rolling = df[columns.keys()] + kwargs: Dict[str, Union[str, int]] = {} + if window is None: + raise QueryObjectValidationError(_("Undefined window for rolling operation")) + if window == 0: + raise QueryObjectValidationError(_("Window must be > 0")) + + kwargs["window"] = window + if min_periods is not None: + kwargs["min_periods"] = min_periods + if center is not None: + kwargs["center"] = center + if win_type is not None: + kwargs["win_type"] = win_type + + df_rolling = df_rolling.rolling(**kwargs) + if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr( + df_rolling, rolling_type + ): + raise QueryObjectValidationError( + _("Invalid rolling_type: %(type)s", type=rolling_type) + ) + try: + df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options) + except TypeError as ex: + raise QueryObjectValidationError( + _( + "Invalid options for %(rolling_type)s: %(options)s", + rolling_type=rolling_type, + options=rolling_type_options, + ) + ) from ex + + if is_pivot_df: + agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list() + agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df} + df_rolling.columns = [ + _flatten_column_after_pivot(col, agg) for col in df_rolling.columns + ] + df_rolling.reset_index(level=0, inplace=True) + else: + df_rolling = _append_columns(df, df_rolling, columns) + + if min_periods: + df_rolling = df_rolling[min_periods:] + return df_rolling diff --git a/superset/utils/pandas_postprocessing/select.py b/superset/utils/pandas_postprocessing/select.py new file mode 100644 index 0000000000000..c2de7b65932db --- /dev/null +++ b/superset/utils/pandas_postprocessing/select.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@validate_column_args("columns", "drop", "rename") +def select( + df: DataFrame, + columns: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + rename: Optional[Dict[str, str]] = None, +) -> DataFrame: + """ + Only select a subset of columns in the original dataset. Can be useful for + removing unnecessary intermediate results, renaming and reordering columns. + + :param df: DataFrame on which the rolling period will be based. + :param columns: Columns which to select from the DataFrame, in the desired order. + If left undefined, all columns will be selected. If columns are + renamed, the original column name should be referenced here. + :param exclude: columns to exclude from selection. If columns are renamed, the new + column name should be referenced here. + :param rename: columns which to rename, mapping source column to target column. + For instance, `{'y': 'y2'}` will rename the column `y` to + `y2`. + :return: Subset of columns in original DataFrame + :raises QueryObjectValidationError: If the request in incorrect + """ + df_select = df.copy(deep=False) + if columns: + df_select = df_select[columns] + if exclude: + df_select = df_select.drop(exclude, axis=1) + if rename is not None: + df_select = df_select.rename(columns=rename) + return df_select diff --git a/superset/utils/pandas_postprocessing/sort.py b/superset/utils/pandas_postprocessing/sort.py new file mode 100644 index 0000000000000..cc2b5f5a03052 --- /dev/null +++ b/superset/utils/pandas_postprocessing/sort.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@validate_column_args("columns") +def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: + """ + Sort a DataFrame. + + :param df: DataFrame to sort. + :param columns: columns by by which to sort. The key specifies the column name, + value specifies if sorting in ascending order. + :return: Sorted DataFrame + :raises QueryObjectValidationError: If the request in incorrect + """ + return df.sort_values(by=list(columns.keys()), ascending=list(columns.values())) + diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py new file mode 100644 index 0000000000000..6cee5b1c7ae8d --- /dev/null +++ b/superset/utils/pandas_postprocessing/utils.py @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from decimal import Decimal +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import geohash as geohash_lib +import numpy as np +import pandas as pd +from flask_babel import gettext as _ +from geopy.point import Point +from pandas import DataFrame, NamedAgg, Series, Timestamp + +from superset.constants import NULL_STRING, PandasAxis, PandasPostprocessingCompare +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import ( + DTTM_ALIAS, + PostProcessingBoxplotWhiskerType, + PostProcessingContributionOrientation, + TIME_COMPARISION, +) + +NUMPY_FUNCTIONS = { + "average": np.average, + "argmin": np.argmin, + "argmax": np.argmax, + "count": np.ma.count, + "count_nonzero": np.count_nonzero, + "cumsum": np.cumsum, + "cumprod": np.cumprod, + "max": np.max, + "mean": np.mean, + "median": np.median, + "nansum": np.nansum, + "nanmin": np.nanmin, + "nanmax": np.nanmax, + "nanmean": np.nanmean, + "nanmedian": np.nanmedian, + "nanpercentile": np.nanpercentile, + "min": np.min, + "percentile": np.percentile, + "prod": np.prod, + "product": np.product, + "std": np.std, + "sum": np.sum, + "var": np.var, +} + +DENYLIST_ROLLING_FUNCTIONS = ( + "count", + "corr", + "cov", + "kurt", + "max", + "mean", + "median", + "min", + "std", + "skew", + "sum", + "var", + "quantile", +) + +ALLOWLIST_CUMULATIVE_FUNCTIONS = ( + "cummax", + "cummin", + "cumprod", + "cumsum", +) + +PROPHET_TIME_GRAIN_MAP = { + "PT1S": "S", + "PT1M": "min", + "PT5M": "5min", + "PT10M": "10min", + "PT15M": "15min", + "PT30M": "30min", + "PT1H": "H", + "P1D": "D", + "P1W": "W", + "P1M": "M", + "P3M": "Q", + "P1Y": "A", + "1969-12-28T00:00:00Z/P1W": "W", + "1969-12-29T00:00:00Z/P1W": "W", + "P1W/1970-01-03T00:00:00Z": "W", + "P1W/1970-01-04T00:00:00Z": "W", +} + + +def _flatten_column_after_pivot( + column: Union[float, Timestamp, str, Tuple[str, ...]], + aggregates: Dict[str, Dict[str, Any]], +) -> str: + """ + Function for flattening column names into a single string. This step is necessary + to be able to properly serialize a DataFrame. If the column is a string, return + element unchanged. For multi-element columns, join column elements with a comma, + with the exception of pivots made with a single aggregate, in which case the + aggregate column name is omitted. + + :param column: single element from `DataFrame.columns` + :param aggregates: aggregates + :return: + """ + if not isinstance(column, tuple): + column = (column,) + if len(aggregates) == 1 and len(column) > 1: + # drop aggregate for single aggregate pivots with multiple groupings + # from column name (aggregates always come first in column name) + column = column[1:] + return ", ".join([str(col) for col in column]) + + +def validate_column_args(*argnames: str) -> Callable[..., Any]: + def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapped(df: DataFrame, **options: Any) -> Any: + if options.get("is_pivot_df"): + # skip validation when pivot Dataframe + return func(df, **options) + columns = df.columns.tolist() + for name in argnames: + if name in options and not all( + elem in columns for elem in options.get(name) or [] + ): + raise QueryObjectValidationError( + _("Referenced columns not available in DataFrame.") + ) + return func(df, **options) + + return wrapped + + return wrapper + + +def _get_aggregate_funcs( + df: DataFrame, aggregates: Dict[str, Dict[str, Any]], +) -> Dict[str, NamedAgg]: + """ + Converts a set of aggregate config objects into functions that pandas can use as + aggregators. Currently only numpy aggregators are supported. + + :param df: DataFrame on which to perform aggregate operation. + :param aggregates: Mapping from column name to aggregate config. + :return: Mapping from metric name to function that takes a single input argument. + """ + agg_funcs: Dict[str, NamedAgg] = {} + for name, agg_obj in aggregates.items(): + column = agg_obj.get("column", name) + if column not in df: + raise QueryObjectValidationError( + _( + "Column referenced by aggregate is undefined: %(column)s", + column=column, + ) + ) + if "operator" not in agg_obj: + raise QueryObjectValidationError( + _("Operator undefined for aggregator: %(name)s", name=name,) + ) + operator = agg_obj["operator"] + if callable(operator): + aggfunc = operator + else: + func = NUMPY_FUNCTIONS.get(operator) + if not func: + raise QueryObjectValidationError( + _("Invalid numpy function: %(operator)s", operator=operator,) + ) + options = agg_obj.get("options", {}) + aggfunc = partial(func, **options) + agg_funcs[name] = NamedAgg(column=column, aggfunc=aggfunc) + + return agg_funcs + + +def _append_columns( + base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str] +) -> DataFrame: + """ + Function for adding columns from one DataFrame to another DataFrame. Calls the + assign method, which overwrites the original column in `base_df` if the column + already exists, and appends the column if the name is not defined. + + :param base_df: DataFrame which to use as the base + :param append_df: DataFrame from which to select data. + :param columns: columns on which to append, mapping source column to + target column. For instance, `{'y': 'y'}` will replace the values in + column `y` in `base_df` with the values in `y` in `append_df`, + while `{'y': 'y2'}` will add a column `y2` to `base_df` based + on values in column `y` in `append_df`, leaving the original column `y` + in `base_df` unchanged. + :return: new DataFrame with combined data from `base_df` and `append_df` + """ + return base_df.assign( + **{target: append_df[source] for source, target in columns.items()} + ) + + From 083b289650bfa3ec6c1d0ba2195a48518544d3dd Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Mon, 14 Feb 2022 15:20:05 +0800 Subject: [PATCH 2/5] wip --- superset/utils/pandas_postprocessing/cum.py | 7 ++++++- superset/utils/pandas_postprocessing/diff.py | 10 ++++++++++ .../utils/pandas_postprocessing/geography.py | 10 ++++++++++ superset/utils/pandas_postprocessing/pivot.py | 13 +++++++++++++ .../utils/pandas_postprocessing/prophet.py | 10 ++++++++++ .../utils/pandas_postprocessing/resample.py | 10 ++++++++-- .../utils/pandas_postprocessing/rolling.py | 13 +++++++++++++ superset/utils/pandas_postprocessing/select.py | 6 ++++++ superset/utils/pandas_postprocessing/sort.py | 7 ++++++- superset/utils/pandas_postprocessing/utils.py | 18 ++---------------- 10 files changed, 84 insertions(+), 20 deletions(-) diff --git a/superset/utils/pandas_postprocessing/cum.py b/superset/utils/pandas_postprocessing/cum.py index 9632a42f90afb..c142b36d4ee9f 100644 --- a/superset/utils/pandas_postprocessing/cum.py +++ b/superset/utils/pandas_postprocessing/cum.py @@ -20,7 +20,12 @@ from pandas import DataFrame from superset.exceptions import QueryObjectValidationError -from superset.utils.pandas_postprocessing.utils import validate_column_args +from superset.utils.pandas_postprocessing.utils import ( + _append_columns, + _flatten_column_after_pivot, + ALLOWLIST_CUMULATIVE_FUNCTIONS, + validate_column_args, +) @validate_column_args("columns") diff --git a/superset/utils/pandas_postprocessing/diff.py b/superset/utils/pandas_postprocessing/diff.py index 866bb59270db1..fd7c83bb91f3d 100644 --- a/superset/utils/pandas_postprocessing/diff.py +++ b/superset/utils/pandas_postprocessing/diff.py @@ -14,6 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict + +from pandas import DataFrame + +from superset.constants import PandasAxis +from superset.utils.pandas_postprocessing.utils import ( + _append_columns, + validate_column_args, +) + @validate_column_args("columns") def diff( diff --git a/superset/utils/pandas_postprocessing/geography.py b/superset/utils/pandas_postprocessing/geography.py index 5373836990272..a1aae59e395a3 100644 --- a/superset/utils/pandas_postprocessing/geography.py +++ b/superset/utils/pandas_postprocessing/geography.py @@ -14,6 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional, Tuple + +import geohash as geohash_lib +from flask_babel import gettext as _ +from geopy.point import Point +from pandas import DataFrame + +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing.utils import _append_columns + def geohash_decode( df: DataFrame, geohash: str, longitude: str, latitude: str diff --git a/superset/utils/pandas_postprocessing/pivot.py b/superset/utils/pandas_postprocessing/pivot.py index f14f98812510c..b9d70e9087cc1 100644 --- a/superset/utils/pandas_postprocessing/pivot.py +++ b/superset/utils/pandas_postprocessing/pivot.py @@ -14,6 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, Dict, List, Optional + +from flask_babel import gettext as _ +from pandas import DataFrame + +from superset.constants import NULL_STRING, PandasAxis +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing.utils import ( + _flatten_column_after_pivot, + _get_aggregate_funcs, + validate_column_args, +) + @validate_column_args("index", "columns") def pivot( # pylint: disable=too-many-arguments,too-many-locals diff --git a/superset/utils/pandas_postprocessing/prophet.py b/superset/utils/pandas_postprocessing/prophet.py index e2d9dc0f66ee7..3ade7f675d426 100644 --- a/superset/utils/pandas_postprocessing/prophet.py +++ b/superset/utils/pandas_postprocessing/prophet.py @@ -14,6 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging +from typing import Optional, Union + +from flask_babel import gettext as _ +from pandas import DataFrame + +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import DTTM_ALIAS +from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP + def _prophet_parse_seasonality( input_value: Optional[Union[bool, int]] diff --git a/superset/utils/pandas_postprocessing/resample.py b/superset/utils/pandas_postprocessing/resample.py index 6e2fa1612b9d7..54e67ac009ee2 100644 --- a/superset/utils/pandas_postprocessing/resample.py +++ b/superset/utils/pandas_postprocessing/resample.py @@ -14,6 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional, Tuple, Union + +from pandas import DataFrame + +from superset.utils.pandas_postprocessing.utils import validate_column_args + @validate_column_args("groupby_columns") def resample( # pylint: disable=too-many-arguments @@ -46,8 +52,8 @@ def _upsampling(_df: DataFrame) -> DataFrame: if groupby_columns: df = ( df.set_index(keys=list(groupby_columns)) - .groupby(by=list(groupby_columns)) - .apply(_upsampling) + .groupby(by=list(groupby_columns)) + .apply(_upsampling) ) df = df.reset_index().set_index(time_column).sort_index() else: diff --git a/superset/utils/pandas_postprocessing/rolling.py b/superset/utils/pandas_postprocessing/rolling.py index 170a6c352d0b9..f93b3da851749 100644 --- a/superset/utils/pandas_postprocessing/rolling.py +++ b/superset/utils/pandas_postprocessing/rolling.py @@ -14,6 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, Dict, Optional, Union + +from flask_babel import gettext as _ +from pandas import DataFrame + +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing.utils import ( + _append_columns, + _flatten_column_after_pivot, + DENYLIST_ROLLING_FUNCTIONS, + validate_column_args, +) + @validate_column_args("columns") def rolling( # pylint: disable=too-many-arguments diff --git a/superset/utils/pandas_postprocessing/select.py b/superset/utils/pandas_postprocessing/select.py index c2de7b65932db..209d50255fc41 100644 --- a/superset/utils/pandas_postprocessing/select.py +++ b/superset/utils/pandas_postprocessing/select.py @@ -14,6 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict, List, Optional + +from pandas import DataFrame + +from superset.utils.pandas_postprocessing.utils import validate_column_args + @validate_column_args("columns", "drop", "rename") def select( diff --git a/superset/utils/pandas_postprocessing/sort.py b/superset/utils/pandas_postprocessing/sort.py index cc2b5f5a03052..fdf8f94c355b0 100644 --- a/superset/utils/pandas_postprocessing/sort.py +++ b/superset/utils/pandas_postprocessing/sort.py @@ -14,6 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Dict + +from pandas import DataFrame + +from superset.utils.pandas_postprocessing.utils import validate_column_args + @validate_column_args("columns") def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: @@ -27,4 +33,3 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: :raises QueryObjectValidationError: If the request in incorrect """ return df.sort_values(by=list(columns.keys()), ascending=list(columns.values())) - diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py index 6cee5b1c7ae8d..7d2699457f2d9 100644 --- a/superset/utils/pandas_postprocessing/utils.py +++ b/superset/utils/pandas_postprocessing/utils.py @@ -14,26 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import logging -from decimal import Decimal from functools import partial -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union -import geohash as geohash_lib import numpy as np -import pandas as pd from flask_babel import gettext as _ -from geopy.point import Point -from pandas import DataFrame, NamedAgg, Series, Timestamp +from pandas import DataFrame, NamedAgg, Timestamp -from superset.constants import NULL_STRING, PandasAxis, PandasPostprocessingCompare from superset.exceptions import QueryObjectValidationError -from superset.utils.core import ( - DTTM_ALIAS, - PostProcessingBoxplotWhiskerType, - PostProcessingContributionOrientation, - TIME_COMPARISION, -) NUMPY_FUNCTIONS = { "average": np.average, @@ -211,5 +199,3 @@ def _append_columns( return base_df.assign( **{target: append_df[source] for source, target in columns.items()} ) - - From 0cf1b5b2363ae492be6347b8be4efb221712d619 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Mon, 14 Feb 2022 15:35:35 +0800 Subject: [PATCH 3/5] format --- .../utils/pandas_postprocessing/__init__.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/superset/utils/pandas_postprocessing/__init__.py b/superset/utils/pandas_postprocessing/__init__.py index 245692337bc3f..976e629b7a6d5 100644 --- a/superset/utils/pandas_postprocessing/__init__.py +++ b/superset/utils/pandas_postprocessing/__init__.py @@ -14,4 +14,40 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from superset.utils.pandas_postprocessing.aggregate import aggregate +from superset.utils.pandas_postprocessing.boxplot import boxplot +from superset.utils.pandas_postprocessing.compare import compare +from superset.utils.pandas_postprocessing.contribution import contribution +from superset.utils.pandas_postprocessing.cum import cum +from superset.utils.pandas_postprocessing.diff import diff +from superset.utils.pandas_postprocessing.geography import ( + geodetic_parse, + geohash_decode, + geohash_encode, +) +from superset.utils.pandas_postprocessing.pivot import pivot +from superset.utils.pandas_postprocessing.prophet import prophet +from superset.utils.pandas_postprocessing.resample import resample +from superset.utils.pandas_postprocessing.rolling import rolling +from superset.utils.pandas_postprocessing.select import select +from superset.utils.pandas_postprocessing.sort import sort +from superset.utils.pandas_postprocessing.utils import _flatten_column_after_pivot +__all__ = [ + "aggregate", + "boxplot", + "compare", + "contribution", + "cum", + "diff", + "geohash_encode", + "geohash_decode", + "geodetic_parse", + "pivot", + "prophet", + "resample", + "rolling", + "select", + "sort", + "_flatten_column_after_pivot", +] From ea73615e4295b290def6e6cc246322bdeaa4a6f3 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Mon, 14 Feb 2022 15:43:18 +0800 Subject: [PATCH 4/5] pylint --- superset/utils/pandas_postprocessing/aggregate.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/superset/utils/pandas_postprocessing/aggregate.py b/superset/utils/pandas_postprocessing/aggregate.py index 5dde70f998e2b..2d6d396e3112f 100644 --- a/superset/utils/pandas_postprocessing/aggregate.py +++ b/superset/utils/pandas_postprocessing/aggregate.py @@ -15,11 +15,12 @@ # specific language governing permissions and limitations # under the License. from typing import Any, Dict, List + from pandas import DataFrame from superset.utils.pandas_postprocessing.utils import ( - validate_column_args, _get_aggregate_funcs, + validate_column_args, ) @@ -43,5 +44,3 @@ def aggregate( else: df_groupby = df.groupby(lambda _: True) return df_groupby.agg(**aggregate_funcs).reset_index(drop=not groupby) - - From 1df183e7b64aff4da182ee46feeee4b77dbafac9 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Mon, 14 Feb 2022 16:00:38 +0800 Subject: [PATCH 5/5] isort --- superset/utils/pandas_postprocessing/boxplot.py | 4 +--- superset/utils/pandas_postprocessing/compare.py | 4 +--- superset/utils/pandas_postprocessing/contribution.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py index ce00ed479c762..9887507d1bc0f 100644 --- a/superset/utils/pandas_postprocessing/boxplot.py +++ b/superset/utils/pandas_postprocessing/boxplot.py @@ -21,9 +21,7 @@ from pandas import DataFrame, Series from superset.exceptions import QueryObjectValidationError -from superset.utils.core import ( - PostProcessingBoxplotWhiskerType, -) +from superset.utils.core import PostProcessingBoxplotWhiskerType from superset.utils.pandas_postprocessing.aggregate import aggregate diff --git a/superset/utils/pandas_postprocessing/compare.py b/superset/utils/pandas_postprocessing/compare.py index 7c10854a4fb86..67f275e659843 100644 --- a/superset/utils/pandas_postprocessing/compare.py +++ b/superset/utils/pandas_postprocessing/compare.py @@ -22,9 +22,7 @@ from superset.constants import PandasPostprocessingCompare from superset.exceptions import QueryObjectValidationError -from superset.utils.core import ( - TIME_COMPARISION, -) +from superset.utils.core import TIME_COMPARISION from superset.utils.pandas_postprocessing.utils import validate_column_args diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py index 1b06b669806f9..d813961fb5f55 100644 --- a/superset/utils/pandas_postprocessing/contribution.py +++ b/superset/utils/pandas_postprocessing/contribution.py @@ -21,9 +21,7 @@ from pandas import DataFrame from superset.exceptions import QueryObjectValidationError -from superset.utils.core import ( - PostProcessingContributionOrientation, -) +from superset.utils.core import PostProcessingContributionOrientation from superset.utils.pandas_postprocessing.utils import validate_column_args