Skip to content

Commit

Permalink
Ensemble calibrate error charts (#5755)
Browse files Browse the repository at this point in the history
  • Loading branch information
jryu01 authored Dec 5, 2024
1 parent 8b4e8e0 commit dd10f97
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ export function getErrorData(
groundTruth: DataArray,
simulationData: DataArray,
mapping: CalibrateMap[],
timestampColName: string
timestampColName: string,
pyciemssMap: Record<string, string>
) {
const errors: DataArray = [];
if (simulationData.length === 0 || groundTruth.length === 0 || !timestampColName) return errors;
const pyciemssMap = parsePyCiemssMap(simulationData[0]);

if (simulationData.length === 0 || groundTruth.length === 0 || !timestampColName || _.isEmpty(pyciemssMap))
return errors;
const datasetVariables = mapping.map((ele) => ele.datasetVariable);
const relevantGroundTruthColumns = Object.keys(groundTruth[0]).filter(
(variable) => datasetVariables.includes(variable) && variable !== timestampColName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,8 @@ const errorData = computed<DataArray>(() =>
groundTruthData.value,
runResult.value,
selectedOutputMapping.value,
selectedOutputTimestampColName.value
selectedOutputTimestampColName.value,
pyciemssMap.value
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import _ from 'lodash';
import { createForecastChart, AUTOSIZE } from '@/services/charts';
import {
DataArray,
extractModelConfigIdsInOrder,
getEnsembleResultModelConfigMap,
getRunResultCSV,
getSimulation,
Expand All @@ -10,12 +11,13 @@ import {
import { EnsembleModelConfigs } from '@/types/Types';
import { WorkflowNode } from '@/types/workflow';
import { getActiveOutput } from '@/components/workflow/util';
import { CalibrateMap } from '@/services/calibrate-workflow';
import {
CalibrateEnsembleCiemssOperationState,
CalibrateEnsembleMappingRow,
CalibrateEnsembleWeights
} from './calibrate-ensemble-ciemss-operation';
import { mergeResults, renameFnGenerator } from '../calibrate-ciemss/calibrate-utils';
import { getErrorData, mergeResults, renameFnGenerator } from '../calibrate-ciemss/calibrate-utils';

export async function getLossValuesFromSimulation(calibrationId: string) {
if (!calibrationId) return [];
Expand Down Expand Up @@ -143,3 +145,43 @@ export function buildChartData(
});
return { ...outputData, translationMap };
}

export interface EnsembleErrorData {
ensemble: DataArray;
[modelConfigId: string]: DataArray;
}

// Get the error data for the ensemble calibration
export function getEnsembleErrorData(
groundTruth: DataArray,
simulationData: DataArray,
mapping: CalibrateEnsembleMappingRow[],
pyciemssMap: Record<string, string>
): EnsembleErrorData {
const errorData: EnsembleErrorData = { ensemble: [] };
const timestampColName = mapping.find((m) => m.newName === 'timepoint_id')?.datasetMapping ?? '';
const mappingWithoutTimeCol = mapping.filter((m) => m.newName !== 'timepoint_id');
// Error data for the ensemble
const calibrateMappings = mappingWithoutTimeCol.map(
(m) =>
({
datasetVariable: m.datasetMapping,
modelVariable: m.datasetMapping
}) as CalibrateMap
);
errorData.ensemble = getErrorData(groundTruth, simulationData, calibrateMappings, timestampColName, pyciemssMap);

// Error data for each model
const modelConfigIds = extractModelConfigIdsInOrder(pyciemssMap);
modelConfigIds.forEach((configId) => {
const cMapping = mappingWithoutTimeCol.map(
(m) =>
({
datasetVariable: m.datasetMapping,
modelVariable: `${configId}/${m.modelConfigurationMappings[configId]}`
}) as CalibrateMap
);
errorData[configId] = getErrorData(groundTruth, simulationData, cMapping, timestampColName, pyciemssMap);
});
return errorData;
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@
<tera-drilldown-section>
<section class="pb-3 px-2">
<div class="mx-2" ref="chartWidthDiv"></div>
<Accordion multiple :active-index="[0, 1, 2]">
<Accordion multiple :active-index="[0, 1, 2, 3]">
<!-- <AccordionTab header="Summary">
</AccordionTab> -->
<AccordionTab v-if="node.state.showLossChart" header="Loss">
Expand All @@ -205,7 +205,7 @@
/>
</AccordionTab>
<template v-if="!isRunInProgress">
<AccordionTab header="Ensemble variables over time">
<AccordionTab v-if="selectedEnsembleVariableSettings.length > 0" header="Ensemble variables over time">
<div class="flex flex-row" v-for="setting of selectedEnsembleVariableSettings" :key="setting.id">
<vega-chart
v-for="(spec, index) of ensembleVariableCharts[setting.id]"
Expand All @@ -216,6 +216,17 @@
/>
</div>
</AccordionTab>
<AccordionTab v-if="selectedErrorVariableSettings.length > 0" header="Error">
<div class="flex flex-row">
<vega-chart
v-for="(spec, index) of errorCharts"
:key="index"
:expandable="() => onExpandErrorChart(index)"
:are-embed-actions-visible="true"
:visualization-spec="spec"
/>
</div>
</AccordionTab>
<AccordionTab v-if="node.state.showModelWeightsCharts" header="Model weights">
<div class="flex flex-row">
<vega-chart
Expand Down Expand Up @@ -278,6 +289,17 @@
@toggle-ensemble-variable-setting-option="updateEnsembleVariableSettingOption"
/>
<Divider />
<tera-chart-settings
:title="'Error'"
:settings="chartSettings"
:type="ChartSettingType.ERROR_DISTRIBUTION"
:select-options="ensembleVariables"
:selected-options="selectedErrorVariableSettings.map((s) => s.selectedVariables[0])"
@open="activeChartSettings = $event"
@remove="removeChartSettings"
@selection-change="updateChartSettings"
/>
<Divider />
<h5>Model Weights</h5>
<tera-checkbox
label="Show distributions in charts"
Expand Down Expand Up @@ -358,7 +380,9 @@ import {
formatCalibrateModelConfigurations,
getSelectedOutputEnsembleMapping,
fetchOutputData,
buildChartData
buildChartData,
getEnsembleErrorData,
EnsembleErrorData
} from './calibrate-ensemble-util';
const props = defineProps<{
Expand Down Expand Up @@ -579,23 +603,38 @@ const {
removeChartSettings,
updateChartSettings,
selectedEnsembleVariableSettings,
selectedErrorVariableSettings,
updateEnsembleVariableSettingOption
} = useChartSettings(props, emit);
const { generateAnnotation, getChartAnnotationsByChartId, useEnsembleVariableCharts, useWeightsDistributionCharts } =
useCharts(
props.node.id,
null,
allModelConfigurations,
computed(() => buildChartData(outputData.value, selectedOutputMapping.value)),
chartSize,
null,
selectedOutputMapping
);
const {
generateAnnotation,
getChartAnnotationsByChartId,
useEnsembleVariableCharts,
useWeightsDistributionCharts,
useEnsembleErrorCharts
} = useCharts(
props.node.id,
null,
allModelConfigurations,
computed(() => buildChartData(outputData.value, selectedOutputMapping.value)),
chartSize,
null,
selectedOutputMapping
);
const errorData = computed<EnsembleErrorData>(() =>
getEnsembleErrorData(
groundTruthData.value,
outputData.value?.result ?? [],
selectedOutputMapping.value,
outputData.value?.pyciemssMap ?? {}
)
);
const ensembleVariables = computed(() => getSelectedOutputEnsembleMapping(props.node, false).map((d) => d.newName));
const ensembleVariableCharts = useEnsembleVariableCharts(selectedEnsembleVariableSettings, groundTruthData);
const weightsDistributionCharts = useWeightsDistributionCharts();
const { errorCharts, onExpandErrorChart } = useEnsembleErrorCharts(selectedErrorVariableSettings, errorData);
// --------------------------------------------------------
watch(
Expand Down
75 changes: 70 additions & 5 deletions packages/client/hmi-client/src/composables/useCharts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {
} from '@/components/workflow/ops/calibrate-ensemble-ciemss/calibrate-ensemble-ciemss-operation';
import { SimulateEnsembleMappingRow } from '@/components/workflow/ops/simulate-ensemble-ciemss/simulate-ensemble-ciemss-operation';
import { getModelConfigName } from '@/services/model-configurations';
import { EnsembleErrorData } from '@/components/workflow/ops/calibrate-ensemble-ciemss/calibrate-ensemble-util';
import { useChartAnnotations } from './useChartAnnotations';

export interface ChartData {
Expand Down Expand Up @@ -457,10 +458,10 @@ export function useCharts(
if (!chartSettings.value.length) return [];
const variables = chartSettings.value
.map((s) => s.selectedVariables[0])
.map((variable) => ({
field: modelVarToDatasetVar(mapping?.value ?? [], variable) as string,
label: variable
}))
.map((variable) => {
const field = modelVarToDatasetVar(mapping?.value ?? [], variable) as string;
return { field, label: `${variable}, ${field}` };
})
.filter((v) => !!v.field);
return variables;
});
Expand Down Expand Up @@ -498,6 +499,69 @@ export function useCharts(
};
};

// Create ensemble calibrate error charts based on chart settings
const useEnsembleErrorCharts = (
chartSettings: ComputedRef<ChartSetting[]>,
errorData: ComputedRef<EnsembleErrorData>
) => {
const getErrorChartVariables = (configId: string) => {
const variables = chartSettings.value
.map((s) => s.selectedVariables[0])
.map((variable) => {
const modelVarName = configId
? getModelConfigVariable(<EnsembleVariableMappings>mapping?.value ?? [], variable, configId)
: variable;
const field = modelVarToDatasetVar(mapping?.value ?? [], variable) as string;
return { field, label: `${modelVarName}, ${field}` };
})
.filter((v) => !!v.field);
return variables;
};

const errorCharts = computed(() => {
if (!isChartReadyToBuild.value) return [];
const data = [errorData.value.ensemble]; // First item is always ensemble error data, and the rest are model error data
const modelConfigIds = extractModelConfigIdsInOrder(chartData.value?.pyciemssMap ?? {});
modelConfigIds.forEach((configId) => {
if (errorData.value[configId]) data.push(errorData.value[configId]);
});
const errorChartSpecs = data.map((ed, index) => {
const spec = createErrorChart(ed, {
title: '',
width: chartSize.value.width / data.length - 30, // Note: error chart adds extra 30px padding on top of the provided width so we subtract 30px to make the chart fit the container
variables: getErrorChartVariables(modelConfigIds[index - 1] ?? ''),
xAxisTitle: 'Mean absolute (MAE)',
color: CATEGORICAL_SCHEME[index % CATEGORICAL_SCHEME.length]
});
return spec;
});
return errorChartSpecs;
});

const onExpandErrorChart = (chartSpecIndex: number) => {
if (!isChartReadyToBuild.value) return {};
const modelConfigIds = extractModelConfigIdsInOrder(chartData.value?.pyciemssMap ?? {});
const errorDataKeys = ['ensemble', ...modelConfigIds];
// Customize the chart size by modifying the spec before expanding the chart
const spec = createErrorChart(errorData.value[errorDataKeys[chartSpecIndex]], {
title: '',
width: window.innerWidth / 1.5,
height: 230,
boxPlotHeight: 50,
areaChartHeight: 150,
variables: getErrorChartVariables(modelConfigIds[chartSpecIndex - 1] ?? ''),
xAxisTitle: 'Mean absolute (MAE)',
color: CATEGORICAL_SCHEME[chartSpecIndex % CATEGORICAL_SCHEME.length]
});
return spec as VisualizationSpec;
};

return {
errorCharts,
onExpandErrorChart
};
};

// Create parameter distribution charts based on chart settings
const useParameterDistributionCharts = (chartSettings: ComputedRef<ChartSetting[]>) => {
const parameterDistributionCharts = computed(() => {
Expand Down Expand Up @@ -594,6 +658,7 @@ export function useCharts(
useEnsembleVariableCharts,
useErrorChart,
useParameterDistributionCharts,
useWeightsDistributionCharts
useWeightsDistributionCharts,
useEnsembleErrorCharts
};
}
15 changes: 12 additions & 3 deletions packages/client/hmi-client/src/services/charts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export interface ErrorChartOptions extends Omit<BaseChartOptions, 'height' | 'yA
height?: number;
areaChartHeight?: number;
boxPlotHeight?: number;
color?: string;
variables: { field: string; label?: string }[];
}

Expand Down Expand Up @@ -92,8 +93,8 @@ export function createErrorChart(dataset: Record<string, any>[], options: ErrorC
const labelFontWeight = 'normal';
const globalFont = 'Figtree';

const areaChartColor = '#1B8073';
const dotColor = '#67B5AC';
const areaChartColor = options.color ?? '#1B8073';
const dotColor = options.color ?? '#1B8073';
const boxPlotColor = '#000';

const width = options.width;
Expand Down Expand Up @@ -144,6 +145,7 @@ export function createErrorChart(dataset: Record<string, any>[], options: ErrorC
},
point: {
color: dotColor,
opacity: 0.7,
filled: true
},
boxplot: {
Expand Down Expand Up @@ -217,7 +219,14 @@ export function createErrorChart(dataset: Record<string, any>[], options: ErrorC
y: {
field: 'Variable Label',
scale: { range: [boxPlotYPosition, boxPlotYPosition] },
axis: { grid: true, labels: true, orient: 'left', offset: 5 }
axis: {
grid: true,
labels: true,
orient: 'left',
offset: 5,
labelAngle: -90,
labelLimit: areaChartHeight + boxPlotHeight + gap
}
}
}
},
Expand Down

0 comments on commit dd10f97

Please sign in to comment.