diff --git a/x-pack/plugins/ml/common/constants/data_frame_analytics.ts b/x-pack/plugins/ml/common/constants/data_frame_analytics.ts new file mode 100644 index 0000000000000..830537cbadbc8 --- /dev/null +++ b/x-pack/plugins/ml/common/constants/data_frame_analytics.ts @@ -0,0 +1,7 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +export const DEFAULT_RESULTS_FIELD = 'ml'; diff --git a/x-pack/plugins/ml/common/types/data_frame_analytics.ts b/x-pack/plugins/ml/common/types/data_frame_analytics.ts index f0aac75047585..60d2ca63dda59 100644 --- a/x-pack/plugins/ml/common/types/data_frame_analytics.ts +++ b/x-pack/plugins/ml/common/types/data_frame_analytics.ts @@ -79,3 +79,9 @@ export interface DataFrameAnalyticsConfig { version: string; allow_lazy_start?: boolean; } + +export enum ANALYSIS_CONFIG_TYPE { + OUTLIER_DETECTION = 'outlier_detection', + REGRESSION = 'regression', + CLASSIFICATION = 'classification', +} diff --git a/x-pack/plugins/ml/common/types/feature_importance.ts b/x-pack/plugins/ml/common/types/feature_importance.ts new file mode 100644 index 0000000000000..d2ab9f6c58608 --- /dev/null +++ b/x-pack/plugins/ml/common/types/feature_importance.ts @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +export interface ClassFeatureImportance { + class_name: string | boolean; + importance: number; +} +export interface FeatureImportance { + feature_name: string; + importance?: number; + classes?: ClassFeatureImportance[]; +} + +export interface TopClass { + class_name: string; + class_probability: number; + class_score: number; +} + +export type TopClasses = TopClass[]; diff --git a/x-pack/plugins/ml/common/util/analytics_utils.ts b/x-pack/plugins/ml/common/util/analytics_utils.ts new file mode 100644 index 0000000000000..d725984a47d66 --- /dev/null +++ b/x-pack/plugins/ml/common/util/analytics_utils.ts @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import { + AnalysisConfig, + ClassificationAnalysis, + OutlierAnalysis, + RegressionAnalysis, + ANALYSIS_CONFIG_TYPE, +} from '../types/data_frame_analytics'; + +export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => { + const keys = Object.keys(arg); + return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION; +}; + +export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => { + const keys = Object.keys(arg); + return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION; +}; + +export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => { + const keys = Object.keys(arg); + return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION; +}; + +export const getDependentVar = ( + analysis: AnalysisConfig +): + | RegressionAnalysis['regression']['dependent_variable'] + | ClassificationAnalysis['classification']['dependent_variable'] => { + let depVar = ''; + + if (isRegressionAnalysis(analysis)) { + depVar = analysis.regression.dependent_variable; + } + + if (isClassificationAnalysis(analysis)) { + depVar = analysis.classification.dependent_variable; + } + return depVar; +}; + +export const getPredictionFieldName = ( + analysis: AnalysisConfig +): + | RegressionAnalysis['regression']['prediction_field_name'] + | ClassificationAnalysis['classification']['prediction_field_name'] => { + // If undefined will be defaulted to dependent_variable when config is created + let predictionFieldName; + if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) { + predictionFieldName = analysis.regression.prediction_field_name; + } else if ( + isClassificationAnalysis(analysis) && + analysis.classification.prediction_field_name !== undefined + ) { + predictionFieldName = analysis.classification.prediction_field_name; + } + return predictionFieldName; +}; + +export const getDefaultPredictionFieldName = (analysis: AnalysisConfig) => { + return `${getDependentVar(analysis)}_prediction`; +}; +export const getPredictedFieldName = ( + resultsField: string, + analysis: AnalysisConfig, + forSort?: boolean +) => { + // default is 'ml' + const predictionFieldName = getPredictionFieldName(analysis); + const predictedField = `${resultsField}.${ + predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(analysis) + }`; + return predictedField; +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/common.ts b/x-pack/plugins/ml/public/application/components/data_grid/common.ts index 1f0fcb63f019d..f252729cc20cd 100644 --- a/x-pack/plugins/ml/public/application/components/data_grid/common.ts +++ b/x-pack/plugins/ml/public/application/components/data_grid/common.ts @@ -119,13 +119,14 @@ export const getDataGridSchemasFromFieldTypes = (fieldTypes: FieldTypes, results schema = 'numeric'; } - if ( - field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`) || - field.includes(`${resultsField}.${TOP_CLASSES}`) - ) { + if (field.includes(`${resultsField}.${TOP_CLASSES}`)) { schema = 'json'; } + if (field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`)) { + schema = 'featureImportance'; + } + return { id: field, schema, isSortable }; }); }; @@ -250,10 +251,6 @@ export const useRenderCellValue = ( return cellValue ? 'true' : 'false'; } - if (typeof cellValue === 'object' && cellValue !== null) { - return JSON.stringify(cellValue); - } - return cellValue; }; }, [indexPattern?.fields, pagination.pageIndex, pagination.pageSize, tableItems]); diff --git a/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx b/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx index d4be2eab13d26..22815fe593d57 100644 --- a/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx +++ b/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx @@ -5,8 +5,7 @@ */ import { isEqual } from 'lodash'; -import React, { memo, useEffect, FC } from 'react'; - +import React, { memo, useEffect, FC, useMemo } from 'react'; import { i18n } from '@kbn/i18n'; import { @@ -24,13 +23,16 @@ import { } from '@elastic/eui'; import { CoreSetup } from 'src/core/public'; - import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_histograms'; -import { INDEX_STATUS } from '../../data_frame_analytics/common'; +import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common'; import { euiDataGridStyle, euiDataGridToolbarSettings } from './common'; import { UseIndexDataReturnType } from './types'; +import { DecisionPathPopover } from './feature_importance/decision_path_popover'; +import { TopClasses } from '../../../../common/types/feature_importance'; +import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics'; + // TODO Fix row hovering + bar highlighting // import { hoveredRow$ } from './column_chart'; @@ -41,6 +43,9 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => ( ); interface PropsWithoutHeader extends UseIndexDataReturnType { + baseline?: number; + analysisType?: ANALYSIS_CONFIG_TYPE; + resultsField?: string; dataTestSubj: string; toastNotifications: CoreSetup['notifications']['toasts']; } @@ -60,6 +65,7 @@ type Props = PropsWithHeader | PropsWithoutHeader; export const DataGrid: FC = memo( (props) => { const { + baseline, chartsVisible, chartsButtonVisible, columnsWithCharts, @@ -80,8 +86,10 @@ export const DataGrid: FC = memo( toastNotifications, toggleChartVisibility, visibleColumns, + predictionFieldName, + resultsField, + analysisType, } = props; - // TODO Fix row hovering + bar highlighting // const getRowProps = (item: any) => { // return { @@ -90,6 +98,45 @@ export const DataGrid: FC = memo( // }; // }; + const popOverContent = useMemo(() => { + return analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION || + analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION + ? { + featureImportance: ({ children }: { cellContentsElement: any; children: any }) => { + const rowIndex = children?.props?.visibleRowIndex; + const row = data[rowIndex]; + if (!row) return
; + // if resultsField for some reason is not available then use ml + const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD; + const parsedFIArray = row[mlResultsField].feature_importance; + let predictedValue: string | number | undefined; + let topClasses: TopClasses = []; + if ( + predictionFieldName !== undefined && + row && + row[mlResultsField][predictionFieldName] !== undefined + ) { + predictedValue = row[mlResultsField][predictionFieldName]; + topClasses = row[mlResultsField].top_classes; + } + + return ( + + ); + }, + } + : undefined; + }, [baseline, data]); + useEffect(() => { if (invalidSortingColumnns.length > 0) { invalidSortingColumnns.forEach((columnId) => { @@ -225,6 +272,7 @@ export const DataGrid: FC = memo( } : {}), }} + popoverContents={popOverContent} pagination={{ ...pagination, pageSizeOptions: [5, 10, 25], diff --git a/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_chart.tsx b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_chart.tsx new file mode 100644 index 0000000000000..b546ac1db57dd --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_chart.tsx @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import { + AnnotationDomainTypes, + Axis, + AxisStyle, + Chart, + LineAnnotation, + LineAnnotationStyle, + LineAnnotationDatum, + LineSeries, + PartialTheme, + Position, + RecursivePartial, + ScaleType, + Settings, +} from '@elastic/charts'; +import { EuiIcon } from '@elastic/eui'; + +import React, { useCallback, useMemo } from 'react'; +import { i18n } from '@kbn/i18n'; +import euiVars from '@elastic/eui/dist/eui_theme_light.json'; +import { DecisionPathPlotData } from './use_classification_path_data'; + +const { euiColorFullShade, euiColorMediumShade } = euiVars; +const axisColor = euiColorMediumShade; + +const baselineStyle: LineAnnotationStyle = { + line: { + strokeWidth: 1, + stroke: euiColorFullShade, + opacity: 0.75, + }, + details: { + fontFamily: 'Arial', + fontSize: 10, + fontStyle: 'bold', + fill: euiColorMediumShade, + padding: 0, + }, +}; + +const axes: RecursivePartial = { + axisLine: { + stroke: axisColor, + }, + tickLabel: { + fontSize: 10, + fill: axisColor, + }, + tickLine: { + stroke: axisColor, + }, + gridLine: { + horizontal: { + dash: [1, 2], + }, + vertical: { + strokeWidth: 0, + }, + }, +}; +const theme: PartialTheme = { + axes, +}; + +interface DecisionPathChartProps { + decisionPathData: DecisionPathPlotData; + predictionFieldName?: string; + baseline?: number; + minDomain: number | undefined; + maxDomain: number | undefined; +} + +const DECISION_PATH_MARGIN = 125; +const DECISION_PATH_ROW_HEIGHT = 10; +const NUM_PRECISION = 3; +const AnnotationBaselineMarker = ; + +export const DecisionPathChart = ({ + decisionPathData, + predictionFieldName, + minDomain, + maxDomain, + baseline, +}: DecisionPathChartProps) => { + // adjust the height so it's compact for items with more features + const baselineData: LineAnnotationDatum[] = useMemo( + () => [ + { + dataValue: baseline, + header: baseline ? baseline.toPrecision(NUM_PRECISION) : '', + details: i18n.translate( + 'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText', + { + defaultMessage: + 'baseline (average of predictions for all data points in the training data set)', + } + ), + }, + ], + [baseline] + ); + // guarantee up to num_precision significant digits + // without having it in scientific notation + const tickFormatter = useCallback((d) => Number(d.toPrecision(NUM_PRECISION)).toString(), []); + + return ( + + + {baseline && ( + + )} + + + + + + ); +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_classification.tsx b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_classification.tsx new file mode 100644 index 0000000000000..bd001fa81a582 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_classification.tsx @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import React, { FC, useMemo, useState } from 'react'; +import { i18n } from '@kbn/i18n'; +import { EuiHealth, EuiSpacer, EuiSuperSelect, EuiTitle } from '@elastic/eui'; +import d3 from 'd3'; +import { + isDecisionPathData, + useDecisionPathData, + getStringBasedClassName, +} from './use_classification_path_data'; +import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance'; +import { DecisionPathChart } from './decision_path_chart'; +import { MissingDecisionPathCallout } from './missing_decision_path_callout'; + +interface ClassificationDecisionPathProps { + predictedValue: string | boolean; + predictionFieldName?: string; + featureImportance: FeatureImportance[]; + topClasses: TopClasses; +} + +export const ClassificationDecisionPath: FC = ({ + featureImportance, + predictedValue, + topClasses, + predictionFieldName, +}) => { + const [currentClass, setCurrentClass] = useState( + getStringBasedClassName(topClasses[0].class_name) + ); + const { decisionPathData } = useDecisionPathData({ + featureImportance, + predictedValue: currentClass, + }); + const options = useMemo(() => { + const predictionValueStr = getStringBasedClassName(predictedValue); + + return Array.isArray(topClasses) + ? topClasses.map((c) => { + const className = getStringBasedClassName(c.class_name); + return { + value: className, + inputDisplay: + className === predictionValueStr ? ( + + {className} + + ) : ( + className + ), + }; + }) + : undefined; + }, [topClasses, predictedValue]); + + const domain = useMemo(() => { + let maxDomain; + let minDomain; + // if decisionPathData has calculated cumulative path + if (decisionPathData && isDecisionPathData(decisionPathData)) { + const [min, max] = d3.extent(decisionPathData, (d: [string, number, number]) => d[2]); + const buffer = Math.abs(max - min) * 0.1; + maxDomain = max + buffer; + minDomain = min - buffer; + } + return { maxDomain, minDomain }; + }, [decisionPathData]); + + if (!decisionPathData) return ; + + return ( + <> + + + + {i18n.translate( + 'xpack.ml.dataframe.analytics.explorationResults.classificationDecisionPathClassNameTitle', + { + defaultMessage: 'Class name', + } + )} + + + {options !== undefined && ( + + )} + + + ); +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_json_viewer.tsx b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_json_viewer.tsx new file mode 100644 index 0000000000000..343324b27f9b5 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_json_viewer.tsx @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import React, { FC } from 'react'; +import { EuiCodeBlock } from '@elastic/eui'; +import { FeatureImportance } from '../../../../../common/types/feature_importance'; + +interface DecisionPathJSONViewerProps { + featureImportance: FeatureImportance[]; +} +export const DecisionPathJSONViewer: FC = ({ featureImportance }) => { + return {JSON.stringify(featureImportance)}; +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_popover.tsx b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_popover.tsx new file mode 100644 index 0000000000000..263337f93e9a8 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_popover.tsx @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import React, { FC, useState } from 'react'; +import { EuiLink, EuiTab, EuiTabs, EuiText } from '@elastic/eui'; +import { FormattedMessage } from '@kbn/i18n/react'; +import { RegressionDecisionPath } from './decision_path_regression'; +import { DecisionPathJSONViewer } from './decision_path_json_viewer'; +import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance'; +import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common'; +import { ClassificationDecisionPath } from './decision_path_classification'; +import { useMlKibana } from '../../../contexts/kibana'; + +interface DecisionPathPopoverProps { + featureImportance: FeatureImportance[]; + analysisType: ANALYSIS_CONFIG_TYPE; + predictionFieldName?: string; + baseline?: number; + predictedValue?: number | string | undefined; + topClasses?: TopClasses; +} + +enum DECISION_PATH_TABS { + CHART = 'decision_path_chart', + JSON = 'decision_path_json', +} + +export interface ExtendedFeatureImportance extends FeatureImportance { + absImportance?: number; +} + +export const DecisionPathPopover: FC = ({ + baseline, + featureImportance, + predictedValue, + topClasses, + analysisType, + predictionFieldName, +}) => { + const [selectedTabId, setSelectedTabId] = useState(DECISION_PATH_TABS.CHART); + const { + services: { docLinks }, + } = useMlKibana(); + const { ELASTIC_WEBSITE_URL, DOC_LINK_VERSION } = docLinks; + + if (featureImportance.length < 2) { + return ; + } + + const tabs = [ + { + id: DECISION_PATH_TABS.CHART, + name: ( + + ), + }, + { + id: DECISION_PATH_TABS.JSON, + name: ( + + ), + }, + ]; + + return ( + <> +
+ + {tabs.map((tab) => ( + setSelectedTabId(tab.id)} + key={tab.id} + > + {tab.name} + + ))} + +
+ {selectedTabId === DECISION_PATH_TABS.CHART && ( + <> + + + + + ), + }} + /> + + {analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && ( + + )} + {analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION && ( + + )} + + )} + {selectedTabId === DECISION_PATH_TABS.JSON && ( + + )} + + ); +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_regression.tsx b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_regression.tsx new file mode 100644 index 0000000000000..345269a944f02 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/decision_path_regression.tsx @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import React, { FC, useMemo } from 'react'; +import { EuiCallOut } from '@elastic/eui'; +import { FormattedMessage } from '@kbn/i18n/react'; +import d3 from 'd3'; +import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance'; +import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data'; +import { DecisionPathChart } from './decision_path_chart'; +import { MissingDecisionPathCallout } from './missing_decision_path_callout'; + +interface RegressionDecisionPathProps { + predictionFieldName?: string; + baseline?: number; + predictedValue?: number | undefined; + featureImportance: FeatureImportance[]; + topClasses?: TopClasses; +} + +export const RegressionDecisionPath: FC = ({ + baseline, + featureImportance, + predictedValue, + predictionFieldName, +}) => { + const { decisionPathData } = useDecisionPathData({ + baseline, + featureImportance, + predictedValue, + }); + const domain = useMemo(() => { + let maxDomain; + let minDomain; + // if decisionPathData has calculated cumulative path + if (decisionPathData && isDecisionPathData(decisionPathData)) { + const [min, max] = d3.extent(decisionPathData, (d: [string, number, number]) => d[2]); + maxDomain = max; + minDomain = min; + const buffer = Math.abs(maxDomain - minDomain) * 0.1; + maxDomain = + (typeof baseline === 'number' ? Math.max(maxDomain, baseline) : maxDomain) + buffer; + minDomain = + (typeof baseline === 'number' ? Math.min(minDomain, baseline) : minDomain) - buffer; + } + return { maxDomain, minDomain }; + }, [decisionPathData, baseline]); + + if (!decisionPathData) return ; + + return ( + <> + {baseline === undefined && ( + + } + color="warning" + iconType="alert" + /> + )} + + + ); +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/missing_decision_path_callout.tsx b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/missing_decision_path_callout.tsx new file mode 100644 index 0000000000000..66eb2047b1314 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/missing_decision_path_callout.tsx @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import React from 'react'; +import { EuiCallOut } from '@elastic/eui'; +import { FormattedMessage } from '@kbn/i18n/react'; + +export const MissingDecisionPathCallout = () => { + return ( + + + + ); +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/use_classification_path_data.tsx b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/use_classification_path_data.tsx new file mode 100644 index 0000000000000..90216c4a58ffc --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/data_grid/feature_importance/use_classification_path_data.tsx @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import { useMemo } from 'react'; +import { i18n } from '@kbn/i18n'; +import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance'; +import { ExtendedFeatureImportance } from './decision_path_popover'; + +export type DecisionPathPlotData = Array<[string, number, number]>; + +interface UseDecisionPathDataParams { + featureImportance: FeatureImportance[]; + baseline?: number; + predictedValue?: string | number | undefined; + topClasses?: TopClasses; +} + +interface RegressionDecisionPathProps { + baseline?: number; + predictedValue?: number | undefined; + featureImportance: FeatureImportance[]; + topClasses?: TopClasses; +} +const FEATURE_NAME = 'feature_name'; +const FEATURE_IMPORTANCE = 'importance'; + +export const isDecisionPathData = (decisionPathData: any): boolean => { + return ( + Array.isArray(decisionPathData) && + decisionPathData.length > 0 && + decisionPathData[0].length === 3 + ); +}; + +// cast to 'True' | 'False' | value to match Eui display +export const getStringBasedClassName = (v: string | boolean | undefined | number): string => { + if (v === undefined) { + return ''; + } + if (typeof v === 'boolean') { + return v ? 'True' : 'False'; + } + if (typeof v === 'number') { + return v.toString(); + } + return v; +}; + +export const useDecisionPathData = ({ + baseline, + featureImportance, + predictedValue, +}: UseDecisionPathDataParams): { decisionPathData: DecisionPathPlotData | undefined } => { + const decisionPathData = useMemo(() => { + return baseline + ? buildRegressionDecisionPathData({ + baseline, + featureImportance, + predictedValue: predictedValue as number | undefined, + }) + : buildClassificationDecisionPathData({ + featureImportance, + currentClass: predictedValue as string | undefined, + }); + }, [baseline, featureImportance, predictedValue]); + + return { decisionPathData }; +}; + +export const buildDecisionPathData = (featureImportance: ExtendedFeatureImportance[]) => { + const finalResult: DecisionPathPlotData = featureImportance + // sort so absolute importance so it goes from bottom (baseline) to top + .sort( + (a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) => + b.absImportance! - a.absImportance! + ) + .map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]); + + // start at the baseline and end at predicted value + // for regression, cumulativeSum should add up to baseline + let cumulativeSum = 0; + for (let i = featureImportance.length - 1; i >= 0; i--) { + cumulativeSum += finalResult[i][1]; + finalResult[i][2] = cumulativeSum; + } + return finalResult; +}; +export const buildRegressionDecisionPathData = ({ + baseline, + featureImportance, + predictedValue, +}: RegressionDecisionPathProps): DecisionPathPlotData | undefined => { + let mappedFeatureImportance: ExtendedFeatureImportance[] = featureImportance; + mappedFeatureImportance = mappedFeatureImportance.map((d) => ({ + ...d, + absImportance: Math.abs(d[FEATURE_IMPORTANCE] as number), + })); + + if (baseline && predictedValue !== undefined && Number.isFinite(predictedValue)) { + // get the adjusted importance needed for when # of fields included in c++ analysis != max allowed + // if num fields included = num features allowed exactly, adjustedImportance should be 0 + const adjustedImportance = + predictedValue - + mappedFeatureImportance.reduce( + (accumulator, currentValue) => accumulator + currentValue.importance!, + 0 + ) - + baseline; + + mappedFeatureImportance.push({ + [FEATURE_NAME]: i18n.translate( + 'xpack.ml.dataframe.analytics.decisionPathFeatureBaselineTitle', + { + defaultMessage: 'baseline', + } + ), + [FEATURE_IMPORTANCE]: baseline, + absImportance: -1, + }); + + // if the difference is small enough then no need to plot the residual feature importance + if (Math.abs(adjustedImportance) > 1e-5) { + mappedFeatureImportance.push({ + [FEATURE_NAME]: i18n.translate( + 'xpack.ml.dataframe.analytics.decisionPathFeatureOtherTitle', + { + defaultMessage: 'other', + } + ), + [FEATURE_IMPORTANCE]: adjustedImportance, + absImportance: 0, // arbitrary importance so this will be of higher importance than baseline + }); + } + } + const filteredFeatureImportance = mappedFeatureImportance.filter( + (f) => f !== undefined + ) as ExtendedFeatureImportance[]; + + return buildDecisionPathData(filteredFeatureImportance); +}; + +export const buildClassificationDecisionPathData = ({ + featureImportance, + currentClass, +}: { + featureImportance: FeatureImportance[]; + currentClass: string | undefined; +}): DecisionPathPlotData | undefined => { + if (currentClass === undefined) return []; + const mappedFeatureImportance: Array< + ExtendedFeatureImportance | undefined + > = featureImportance.map((feature) => { + const classFeatureImportance = Array.isArray(feature.classes) + ? feature.classes.find((c) => getStringBasedClassName(c.class_name) === currentClass) + : feature; + if (classFeatureImportance && typeof classFeatureImportance[FEATURE_IMPORTANCE] === 'number') { + return { + [FEATURE_NAME]: feature[FEATURE_NAME], + [FEATURE_IMPORTANCE]: classFeatureImportance[FEATURE_IMPORTANCE], + absImportance: Math.abs(classFeatureImportance[FEATURE_IMPORTANCE] as number), + }; + } + return undefined; + }); + const filteredFeatureImportance = mappedFeatureImportance.filter( + (f) => f !== undefined + ) as ExtendedFeatureImportance[]; + + return buildDecisionPathData(filteredFeatureImportance); +}; diff --git a/x-pack/plugins/ml/public/application/components/data_grid/types.ts b/x-pack/plugins/ml/public/application/components/data_grid/types.ts index 756f74c8f9302..f9ee8c37fabf7 100644 --- a/x-pack/plugins/ml/public/application/components/data_grid/types.ts +++ b/x-pack/plugins/ml/public/application/components/data_grid/types.ts @@ -74,6 +74,9 @@ export interface UseIndexDataReturnType | 'tableItems' | 'toggleChartVisibility' | 'visibleColumns' + | 'baseline' + | 'predictionFieldName' + | 'resultsField' > { renderCellValue: RenderCellValue; } @@ -105,4 +108,7 @@ export interface UseDataGridReturnType { tableItems: DataGridItem[]; toggleChartVisibility: () => void; visibleColumns: ColumnId[]; + baseline?: number; + predictionFieldName?: string; + resultsField?: string; } diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts index 8ad861e616b7a..97098ea9e75c6 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts @@ -15,18 +15,19 @@ import { SavedSearchQuery } from '../../contexts/ml'; import { AnalysisConfig, ClassificationAnalysis, - OutlierAnalysis, RegressionAnalysis, + ANALYSIS_CONFIG_TYPE, } from '../../../../common/types/data_frame_analytics'; - +import { + isOutlierAnalysis, + isRegressionAnalysis, + isClassificationAnalysis, + getPredictionFieldName, + getDependentVar, + getPredictedFieldName, +} from '../../../../common/util/analytics_utils'; export type IndexPattern = string; -export enum ANALYSIS_CONFIG_TYPE { - OUTLIER_DETECTION = 'outlier_detection', - REGRESSION = 'regression', - CLASSIFICATION = 'classification', -} - export enum ANALYSIS_ADVANCED_FIELDS { ETA = 'eta', FEATURE_BAG_FRACTION = 'feature_bag_fraction', @@ -156,23 +157,6 @@ export const getAnalysisType = (analysis: AnalysisConfig): string => { return 'unknown'; }; -export const getDependentVar = ( - analysis: AnalysisConfig -): - | RegressionAnalysis['regression']['dependent_variable'] - | ClassificationAnalysis['classification']['dependent_variable'] => { - let depVar = ''; - - if (isRegressionAnalysis(analysis)) { - depVar = analysis.regression.dependent_variable; - } - - if (isClassificationAnalysis(analysis)) { - depVar = analysis.classification.dependent_variable; - } - return depVar; -}; - export const getTrainingPercent = ( analysis: AnalysisConfig ): @@ -190,24 +174,6 @@ export const getTrainingPercent = ( return trainingPercent; }; -export const getPredictionFieldName = ( - analysis: AnalysisConfig -): - | RegressionAnalysis['regression']['prediction_field_name'] - | ClassificationAnalysis['classification']['prediction_field_name'] => { - // If undefined will be defaulted to dependent_variable when config is created - let predictionFieldName; - if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) { - predictionFieldName = analysis.regression.prediction_field_name; - } else if ( - isClassificationAnalysis(analysis) && - analysis.classification.prediction_field_name !== undefined - ) { - predictionFieldName = analysis.classification.prediction_field_name; - } - return predictionFieldName; -}; - export const getNumTopClasses = ( analysis: AnalysisConfig ): ClassificationAnalysis['classification']['num_top_classes'] => { @@ -238,35 +204,6 @@ export const getNumTopFeatureImportanceValues = ( return numTopFeatureImportanceValues; }; -export const getPredictedFieldName = ( - resultsField: string, - analysis: AnalysisConfig, - forSort?: boolean -) => { - // default is 'ml' - const predictionFieldName = getPredictionFieldName(analysis); - const defaultPredictionField = `${getDependentVar(analysis)}_prediction`; - const predictedField = `${resultsField}.${ - predictionFieldName ? predictionFieldName : defaultPredictionField - }`; - return predictedField; -}; - -export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => { - const keys = Object.keys(arg); - return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION; -}; - -export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => { - const keys = Object.keys(arg); - return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION; -}; - -export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => { - const keys = Object.keys(arg); - return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION; -}; - export const isResultsSearchBoolQuery = (arg: any): arg is ResultsSearchBoolQuery => { if (arg === undefined) return false; const keys = Object.keys(arg); @@ -607,3 +544,13 @@ export const loadDocsCount = async ({ }; } }; + +export { + isOutlierAnalysis, + isRegressionAnalysis, + isClassificationAnalysis, + getPredictionFieldName, + ANALYSIS_CONFIG_TYPE, + getDependentVar, + getPredictedFieldName, +}; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/constants.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/constants.ts index 2f14dfdfdfca3..c2295a92af89c 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/constants.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/constants.ts @@ -3,8 +3,6 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ - -export const DEFAULT_RESULTS_FIELD = 'ml'; export const FEATURE_IMPORTANCE = 'feature_importance'; export const FEATURE_INFLUENCE = 'feature_influence'; export const TOP_CLASSES = 'top_classes'; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts index 847aefefbc6c8..f9c9bf26a9d16 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts @@ -4,17 +4,16 @@ * you may not use this file except in compliance with the Elastic License. */ +import { getNumTopClasses, getNumTopFeatureImportanceValues } from './analytics'; +import { Field } from '../../../../common/types/fields'; import { - getNumTopClasses, - getNumTopFeatureImportanceValues, getPredictedFieldName, getDependentVar, getPredictionFieldName, isClassificationAnalysis, isOutlierAnalysis, isRegressionAnalysis, -} from './analytics'; -import { Field } from '../../../../common/types/fields'; +} from '../../../../common/util/analytics_utils'; import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public'; import { newJobCapsService } from '../../services/new_job_capabilities_service'; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/classification_exploration.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/classification_exploration.tsx index ccac9a697210b..2e3a5d89367ce 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/classification_exploration.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/classification_exploration.tsx @@ -9,7 +9,6 @@ import React, { FC } from 'react'; import { i18n } from '@kbn/i18n'; import { ExplorationPageWrapper } from '../exploration_page_wrapper'; - import { EvaluatePanel } from './evaluate_panel'; interface Props { diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_page_wrapper/exploration_page_wrapper.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_page_wrapper/exploration_page_wrapper.tsx index 34ff36c59fa6c..84b44ef0d349f 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_page_wrapper/exploration_page_wrapper.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_page_wrapper/exploration_page_wrapper.tsx @@ -51,7 +51,6 @@ export const ExplorationPageWrapper: FC = ({ jobId, title, EvaluatePanel /> ); } - return ( <> {isLoadingJobConfig === true && jobConfig === undefined && } diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx index 8395a11bd6fda..eea579ef1d064 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx @@ -28,6 +28,8 @@ import { INDEX_STATUS, SEARCH_SIZE, defaultSearchQuery, + getAnalysisType, + ANALYSIS_CONFIG_TYPE, } from '../../../../common'; import { getTaskStateBadge } from '../../../analytics_management/components/analytics_list/use_columns'; import { DATA_FRAME_TASK_STATE } from '../../../analytics_management/components/analytics_list/common'; @@ -36,6 +38,7 @@ import { ExplorationQueryBar } from '../exploration_query_bar'; import { IndexPatternPrompt } from '../index_pattern_prompt'; import { useExplorationResults } from './use_exploration_results'; +import { useMlKibana } from '../../../../../contexts/kibana'; const showingDocs = i18n.translate( 'xpack.ml.dataframe.analytics.explorationResults.documentsShownHelpText', @@ -70,18 +73,27 @@ export const ExplorationResultsTable: FC = React.memo( setEvaluateSearchQuery, title, }) => { + const { + services: { + mlServices: { mlApiServices }, + }, + } = useMlKibana(); const [searchQuery, setSearchQuery] = useState(defaultSearchQuery); useEffect(() => { setEvaluateSearchQuery(searchQuery); }, [JSON.stringify(searchQuery)]); + const analysisType = getAnalysisType(jobConfig.analysis); + const classificationData = useExplorationResults( indexPattern, jobConfig, searchQuery, - getToastNotifications() + getToastNotifications(), + mlApiServices ); + const docFieldsCount = classificationData.columnsWithCharts.length; const { columnsWithCharts, @@ -94,7 +106,6 @@ export const ExplorationResultsTable: FC = React.memo( if (jobConfig === undefined || classificationData === undefined) { return null; } - // if it's a searchBar syntax error leave the table visible so they can try again if (status === INDEX_STATUS.ERROR && !errorMessage.includes('failed to create query')) { return ( @@ -184,6 +195,7 @@ export const ExplorationResultsTable: FC = React.memo( {...classificationData} dataTestSubj="mlExplorationDataGrid" toastNotifications={getToastNotifications()} + analysisType={(analysisType as unknown) as ANALYSIS_CONFIG_TYPE} /> diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/use_exploration_results.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/use_exploration_results.ts index 8d53214d23d47..a56345017258e 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/use_exploration_results.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/use_exploration_results.ts @@ -4,12 +4,14 @@ * you may not use this file except in compliance with the Elastic License. */ -import { useEffect, useMemo } from 'react'; +import { useCallback, useEffect, useMemo, useState } from 'react'; import { EuiDataGridColumn } from '@elastic/eui'; import { CoreSetup } from 'src/core/public'; +import { i18n } from '@kbn/i18n'; +import { MlApiServices } from '../../../../../services/ml_api_service'; import { IndexPattern } from '../../../../../../../../../../src/plugins/data/public'; import { DataLoader } from '../../../../../datavisualizer/index_based/data_loader'; @@ -23,21 +25,26 @@ import { UseIndexDataReturnType, } from '../../../../../components/data_grid'; import { SavedSearchQuery } from '../../../../../contexts/ml'; - import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../../../common'; import { - DEFAULT_RESULTS_FIELD, - FEATURE_IMPORTANCE, - TOP_CLASSES, -} from '../../../../common/constants'; + getPredictionFieldName, + getDefaultPredictionFieldName, +} from '../../../../../../../common/util/analytics_utils'; +import { FEATURE_IMPORTANCE, TOP_CLASSES } from '../../../../common/constants'; +import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics'; import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields'; +import { isRegressionAnalysis } from '../../../../common/analytics'; +import { extractErrorMessage } from '../../../../../../../common/util/errors'; export const useExplorationResults = ( indexPattern: IndexPattern | undefined, jobConfig: DataFrameAnalyticsConfig | undefined, searchQuery: SavedSearchQuery, - toastNotifications: CoreSetup['notifications']['toasts'] + toastNotifications: CoreSetup['notifications']['toasts'], + mlApiServices: MlApiServices ): UseIndexDataReturnType => { + const [baseline, setBaseLine] = useState(); + const needsDestIndexFields = indexPattern !== undefined && indexPattern.title === jobConfig?.source.index[0]; @@ -52,7 +59,6 @@ export const useExplorationResults = ( ) ); } - const dataGrid = useDataGrid( columns, 25, @@ -107,16 +113,60 @@ export const useExplorationResults = ( jobConfig?.dest.index, JSON.stringify([searchQuery, dataGrid.visibleColumns]), ]); + const predictionFieldName = useMemo(() => { + if (jobConfig) { + return ( + getPredictionFieldName(jobConfig.analysis) ?? + getDefaultPredictionFieldName(jobConfig.analysis) + ); + } + return undefined; + }, [jobConfig]); + + const getAnalyticsBaseline = useCallback(async () => { + try { + if ( + jobConfig !== undefined && + jobConfig.analysis !== undefined && + isRegressionAnalysis(jobConfig.analysis) + ) { + const result = await mlApiServices.dataFrameAnalytics.getAnalyticsBaseline(jobConfig.id); + if (result?.baseline) { + setBaseLine(result.baseline); + } + } + } catch (e) { + const error = extractErrorMessage(e); + + toastNotifications.addDanger({ + title: i18n.translate( + 'xpack.ml.dataframe.analytics.explorationResults.baselineErrorMessageToast', + { + defaultMessage: 'An error occurred getting feature importance baseline', + } + ), + text: error, + }); + } + }, [mlApiServices, jobConfig]); + + useEffect(() => { + getAnalyticsBaseline(); + }, [jobConfig]); + const resultsField = jobConfig?.dest.results_field ?? DEFAULT_RESULTS_FIELD; const renderCellValue = useRenderCellValue( indexPattern, dataGrid.pagination, dataGrid.tableItems, - jobConfig?.dest.results_field ?? DEFAULT_RESULTS_FIELD + resultsField ); return { ...dataGrid, renderCellValue, + baseline, + predictionFieldName, + resultsField, }; }; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/outlier_exploration/use_outlier_data.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/outlier_exploration/use_outlier_data.ts index 24649ae5f1e71..151e5ea4e6feb 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/outlier_exploration/use_outlier_data.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/outlier_exploration/use_outlier_data.ts @@ -29,7 +29,8 @@ import { SavedSearchQuery } from '../../../../../contexts/ml'; import { getToastNotifications } from '../../../../../util/dependency_cache'; import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../../../common'; -import { DEFAULT_RESULTS_FIELD, FEATURE_INFLUENCE } from '../../../../common/constants'; +import { FEATURE_INFLUENCE } from '../../../../common/constants'; +import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics'; import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields'; import { getFeatureCount, getOutlierScoreFieldName } from './common'; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/action_clone/clone_action_name.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/action_clone/clone_action_name.tsx index 60c699ba0d370..ce24892c9de45 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/action_clone/clone_action_name.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/action_clone/clone_action_name.tsx @@ -12,7 +12,7 @@ import { IIndexPattern } from 'src/plugins/data/common'; import { DeepReadonly } from '../../../../../../../common/types/common'; import { DataFrameAnalyticsConfig, isOutlierAnalysis } from '../../../../common'; import { isClassificationAnalysis, isRegressionAnalysis } from '../../../../common/analytics'; -import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants'; +import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics'; import { useMlKibana, useNavigateToPath } from '../../../../../contexts/kibana'; import { DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES } from '../../hooks/use_create_analytics_form'; import { State } from '../../hooks/use_create_analytics_form/state'; diff --git a/x-pack/plugins/ml/public/application/services/ml_api_service/data_frame_analytics.ts b/x-pack/plugins/ml/public/application/services/ml_api_service/data_frame_analytics.ts index 7de39d91047ef..434200d0383f5 100644 --- a/x-pack/plugins/ml/public/application/services/ml_api_service/data_frame_analytics.ts +++ b/x-pack/plugins/ml/public/application/services/ml_api_service/data_frame_analytics.ts @@ -135,4 +135,10 @@ export const dataFrameAnalytics = { method: 'GET', }); }, + getAnalyticsBaseline(analyticsId: string) { + return http({ + path: `${basePath()}/data_frame/analytics/${analyticsId}/baseline`, + method: 'POST', + }); + }, }; diff --git a/x-pack/plugins/ml/server/models/data_frame_analytics/feature_importance.ts b/x-pack/plugins/ml/server/models/data_frame_analytics/feature_importance.ts new file mode 100644 index 0000000000000..94f54a5654873 --- /dev/null +++ b/x-pack/plugins/ml/server/models/data_frame_analytics/feature_importance.ts @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +import { IScopedClusterClient } from 'kibana/server'; +import { + getDefaultPredictionFieldName, + getPredictionFieldName, + isRegressionAnalysis, +} from '../../../common/util/analytics_utils'; +import { DEFAULT_RESULTS_FIELD } from '../../../common/constants/data_frame_analytics'; +// Obtains data for the data frame analytics feature importance functionalities +// such as baseline, decision paths, or importance summary. +export function analyticsFeatureImportanceProvider({ + asInternalUser, + asCurrentUser, +}: IScopedClusterClient) { + async function getRegressionAnalyticsBaseline(analyticsId: string): Promise { + const { body } = await asInternalUser.ml.getDataFrameAnalytics({ + id: analyticsId, + }); + const jobConfig = body.data_frame_analytics[0]; + if (!isRegressionAnalysis) return undefined; + const destinationIndex = jobConfig.dest.index; + const predictionFieldName = getPredictionFieldName(jobConfig.analysis); + const mlResultsField = jobConfig.dest?.results_field ?? DEFAULT_RESULTS_FIELD; + const predictedField = `${mlResultsField}.${ + predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(jobConfig.analysis) + }`; + const isTrainingField = `${mlResultsField}.is_training`; + + const params = { + index: destinationIndex, + size: 0, + body: { + query: { + bool: { + filter: [ + { + term: { + [isTrainingField]: true, + }, + }, + ], + }, + }, + aggs: { + featureImportanceBaseline: { + avg: { + field: predictedField, + }, + }, + }, + }, + }; + let baseline; + const { body: aggregationResult } = await asCurrentUser.search(params); + if (aggregationResult) { + baseline = aggregationResult.aggregations.featureImportanceBaseline.value; + } + return baseline; + } + + return { + getRegressionAnalyticsBaseline, + }; +} diff --git a/x-pack/plugins/ml/server/routes/data_frame_analytics.ts b/x-pack/plugins/ml/server/routes/data_frame_analytics.ts index dea4803e8275e..7606420eacefc 100644 --- a/x-pack/plugins/ml/server/routes/data_frame_analytics.ts +++ b/x-pack/plugins/ml/server/routes/data_frame_analytics.ts @@ -20,6 +20,7 @@ import { import { IndexPatternHandler } from '../models/data_frame_analytics/index_patterns'; import { DeleteDataFrameAnalyticsWithIndexStatus } from '../../common/types/data_frame_analytics'; import { getAuthorizationHeader } from '../lib/request_authorization'; +import { analyticsFeatureImportanceProvider } from '../models/data_frame_analytics/feature_importance'; function getIndexPatternId(context: RequestHandlerContext, patternName: string) { const iph = new IndexPatternHandler(context.core.savedObjects.client); @@ -545,4 +546,38 @@ export function dataFrameAnalyticsRoutes({ router, mlLicense }: RouteInitializat } }) ); + + /** + * @apiGroup DataFrameAnalytics + * + * @api {get} /api/ml/data_frame/analytics/baseline Get analytics's feature importance baseline + * @apiName GetDataFrameAnalyticsBaseline + * @apiDescription Returns the baseline for data frame analytics job. + * + * @apiSchema (params) analyticsIdSchema + */ + router.post( + { + path: '/api/ml/data_frame/analytics/{analyticsId}/baseline', + validate: { + params: analyticsIdSchema, + }, + options: { + tags: ['access:ml:canGetDataFrameAnalytics'], + }, + }, + mlLicense.fullLicenseAPIGuard(async ({ client, request, response }) => { + try { + const { analyticsId } = request.params; + const { getRegressionAnalyticsBaseline } = analyticsFeatureImportanceProvider(client); + const baseline = await getRegressionAnalyticsBaseline(analyticsId); + + return response.ok({ + body: { baseline }, + }); + } catch (e) { + return response.customError(wrapError(e)); + } + }) + ); } diff --git a/x-pack/test/functional/services/ml/data_frame_analytics_creation.ts b/x-pack/test/functional/services/ml/data_frame_analytics_creation.ts index ffa1d9fd46c75..e01e065867ac7 100644 --- a/x-pack/test/functional/services/ml/data_frame_analytics_creation.ts +++ b/x-pack/test/functional/services/ml/data_frame_analytics_creation.ts @@ -10,25 +10,9 @@ import { FtrProviderContext } from '../../ftr_provider_context'; import { MlCommonUI } from './common_ui'; import { MlApi } from './api'; import { - ClassificationAnalysis, - RegressionAnalysis, -} from '../../../../plugins/ml/common/types/data_frame_analytics'; - -enum ANALYSIS_CONFIG_TYPE { - OUTLIER_DETECTION = 'outlier_detection', - REGRESSION = 'regression', - CLASSIFICATION = 'classification', -} - -const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => { - const keys = Object.keys(arg); - return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION; -}; - -const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => { - const keys = Object.keys(arg); - return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION; -}; + isRegressionAnalysis, + isClassificationAnalysis, +} from '../../../../plugins/ml/common/util/analytics_utils'; export function MachineLearningDataFrameAnalyticsCreationProvider( { getService }: FtrProviderContext,