Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply chart annotations to the forecast charts on calibrate operator #4804

Merged
merged 14 commits into from
Sep 19, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -203,31 +203,35 @@
<section ref="outputPanel" v-if="modelConfig && csvAsset">
<h5>Parameters</h5>
<br />
<template v-for="param of selectedParameters" :key="param">
<template v-for="setting of selectedParameterSettings" :key="setting.id">
<vega-chart
expandable
:are-embed-actions-visible="true"
:visualization-spec="preparedDistributionCharts[param].histogram"
:visualization-spec="preparedDistributionCharts[setting.selectedVariables[0]].histogram"
>
<template v-slot:footer>
<table class="distribution-table">
<thead>
<tr>
<th scope="col"></th>
<th scope="col">{{ preparedDistributionCharts[param].stat.header[0] }}</th>
<th scope="col">{{ preparedDistributionCharts[param].stat.header[1] }}</th>
<th scope="col">
{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.header[0] }}
</th>
<th scope="col">
{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.header[1] }}
</th>
</tr>
</thead>
<tbody>
<tr>
<th scope="row">Mean</th>
<td>{{ preparedDistributionCharts[param].stat.mean[0] }}</td>
<td>{{ preparedDistributionCharts[param].stat.mean[1] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.mean[0] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.mean[1] }}</td>
</tr>
<tr>
<th scope="row">Variance</th>
<td>{{ preparedDistributionCharts[param].stat.variance[0] }}</td>
<td>{{ preparedDistributionCharts[param].stat.variance[1] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.variance[0] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.variance[1] }}</td>
</tr>
</tbody>
</table>
Expand All @@ -236,12 +240,16 @@
</template>
<h5>Variables</h5>
<br />
<template v-for="variable of selectedVariables" :key="variable">
<vega-chart expandable :are-embed-actions-visible="true" :visualization-spec="preparedCharts[variable]" />
<template v-for="setting of selectedVariableSettings" :key="setting.id">
<vega-chart
expandable
:are-embed-actions-visible="true"
:visualization-spec="preparedCharts[setting.selectedVariables[0]]"
/>
</template>
<h5>Errors</h5>
<vega-chart
v-if="errorData.length > 0 && selectedErrorVariables.length > 0"
v-if="errorData.length > 0 && selectedErrorVariableSettings.length > 0"
:expandable="onExpandErrorChart"
:are-embed-actions-visible="true"
:visualization-spec="errorChart"
Expand Down Expand Up @@ -276,8 +284,7 @@
"
:active-settings="activeChartSettings"
:generate-annotation="generateAnnotation"
@create-annotation="addChartAnnotation"
@delete-annotation="removeChartAnnotation"
@delete-annotation="deleteAnnotation"
@close="activeChartSettings = null"
/>
</template>
Expand All @@ -294,7 +301,10 @@
@remove="removeChartSetting"
/>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedParameters }"
:chart-config="{
selectedRun: 'fixme',
selectedVariable: selectedParameterSettings.map((s) => s.selectedVariables[0])
}"
:multi-select="true"
:show-remove-button="false"
:variables="Object.keys(pyciemssMap).filter((c) => modelPartTypesMap[c] === 'parameter')"
Expand All @@ -312,7 +322,10 @@
@remove="removeChartSetting"
/>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedVariables }"
:chart-config="{
selectedRun: 'fixme',
selectedVariable: selectedVariableSettings.map((s) => s.selectedVariables[0])
}"
:multi-select="true"
:show-remove-button="false"
:variables="
Expand All @@ -332,7 +345,10 @@
@remove="removeChartSetting"
/>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedErrorVariables }"
:chart-config="{
selectedRun: 'fixme',
selectedVariable: selectedErrorVariableSettings.map((s) => s.selectedVariables[0])
}"
:multi-select="true"
:show-remove-button="false"
:variables="Object.keys(pyciemssMap).filter((c) => mapping.find((d) => d.modelVariable === c))"
Expand All @@ -357,7 +373,6 @@
<script setup lang="ts">
import _ from 'lodash';
import * as vega from 'vega';
import { v4 as uuidv4 } from 'uuid';
import { csvParse, autoType, mean, variance } from 'd3';
import { computed, onMounted, ref, shallowRef, watch } from 'vue';
import Button from 'primevue/button';
Expand All @@ -366,7 +381,14 @@ import Dropdown from 'primevue/dropdown';
import Column from 'primevue/column';
import TeraInputNumber from '@/components/widgets/tera-input-number.vue';
import { CalibrateMap, setupDatasetInput, setupModelInput } from '@/services/calibrate-workflow';
import { removeChartSettingById, updateChartSettingsBySelectedVariables } from '@/services/chart-settings';
import {
deleteAnnotation,
fetchAnnotations,
generateForecastChartAnnotation,
removeChartSettingById,
saveAnnotation,
updateChartSettingsBySelectedVariables
} from '@/services/chart-settings';
import TeraDrilldown from '@/components/drilldown/tera-drilldown.vue';
import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.vue';
import TeraProgressSpinner from '@/components/widgets/tera-progress-spinner.vue';
Expand Down Expand Up @@ -402,13 +424,19 @@ import {
import { getModelConfigurationById } from '@/services/model-configurations';

import { WorkflowNode } from '@/types/workflow';
import { createForecastChart, createHistogramChart, createErrorChart } from '@/services/charts';
import {
createForecastChart,
createHistogramChart,
createErrorChart,
applyForecastChartAnnotations
} from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
import TeraChartControl from '@/components/workflow/tera-chart-control.vue';
import TeraInputText from '@/components/widgets/tera-input-text.vue';
import { displayNumber } from '@/utils/number';
import TeraPyciemssCancelButton from '@/components/pyciemss/tera-pyciemss-cancel-button.vue';
import TeraSaveAssetModal from '@/components/project/tera-save-asset-modal.vue';
import { useClientEvent } from '@/composables/useClientEvent';
import type { CalibrationOperationStateCiemss } from './calibrate-operation';
import { renameFnGenerator, mergeResults, getErrorData } from './calibrate-utils';

Expand Down Expand Up @@ -560,50 +588,44 @@ const outputPanel = ref(null);
const chartSize = computed(() => drilldownChartSize(outputPanel.value));

const chartSettings = computed(() => props.node.state.chartSettings ?? []);
const selectedParameters = computed(() =>
chartSettings.value
.filter((setting) => setting.type === ChartSettingType.DISTRIBUTION_COMPARISON)
.map((setting) => setting.selectedVariables[0])
const selectedParameterSettings = computed(() =>
chartSettings.value.filter((setting) => setting.type === ChartSettingType.DISTRIBUTION_COMPARISON)
);
const selectedVariables = computed(() =>
chartSettings.value
.filter((setting) => setting.type === ChartSettingType.VARIABLE_COMPARISON)
.map((setting) => setting.selectedVariables[0])
const selectedVariableSettings = computed(() =>
chartSettings.value.filter((setting) => setting.type === ChartSettingType.VARIABLE_COMPARISON)
);
const selectedErrorVariables = computed(() =>
chartSettings.value
.filter((setting) => setting.type === ChartSettingType.ERROR_DISTRIBUTION)
.map((setting) => setting.selectedVariables[0])

const selectedErrorVariableSettings = computed(() =>
chartSettings.value.filter((setting) => setting.type === ChartSettingType.ERROR_DISTRIBUTION)
);

// --- Handle chart annotations
const chartAnnotations = ref<ChartAnnotation[]>([]);
const generateAnnotation = async (setting: ChartSetting, query: string) => {
// Generate fake annotation. The annotation generation logic for the specific chart setting should go here
// Different chart settings type may have different annotation generation logic
await new Promise((resolve) => {
setTimeout(resolve, 1000);
});
const annotation: ChartAnnotation = {
id: uuidv4(),
description: query,
nodeId: props.node.id,
outputId: '',
chartId: setting.id,
layerSpec: {},
llmGenerated: false,
metadata: {}
};
return annotation;
const updateChartAnnotations = async () => {
chartAnnotations.value = await fetchAnnotations(props.node.id);
};
const addChartAnnotation = (annotation: ChartAnnotation) => {
chartAnnotations.value.push(annotation);
};
const removeChartAnnotation = (annotationId: string) => {
const index = chartAnnotations.value.findIndex((annotation) => annotation.id === annotationId);
if (index !== -1) {
chartAnnotations.value.splice(index, 1);
}
onMounted(() => updateChartAnnotations());
useClientEvent([ClientEventType.ChartAnnotationCreate, ClientEventType.ChartAnnotationDelete], updateChartAnnotations);

const generateAnnotation = async (setting: ChartSetting, query: string) => {
// Note: Currently llm generated chart annotations are supported for the forecast chart only
if (!preparedChartInputs.value) return {};
const { reverseMap } = preparedChartInputs.value;
const variable = setting.selectedVariables[0];
const annotationLayerSpec = await generateForecastChartAnnotation(
query,
'timpoint_id',
[`${pyciemssMap.value[variable]}_mean:pre`, `${pyciemssMap.value[variable]}_mean`],
{
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: modelVarUnits.value[variable] || ''
}
);
const saved = await saveAnnotation(annotationLayerSpec, props.node.id, setting.id);
return saved;
};
// ---

const pyciemssMap = ref<Record<string, string>>({});
const preparedChartInputs = computed(() => {
Expand Down Expand Up @@ -644,40 +666,45 @@ const preparedCharts = computed(() => {
const datasetTimeField = state.mapping.find((d) => d.modelVariable === 'timestamp')?.datasetVariable;

const charts = {};
selectedVariables.value.forEach((variable) => {
selectedVariableSettings.value.forEach((settings) => {
const variable = settings.selectedVariables[0];
const annotations = chartAnnotations.value.filter((annotation) => annotation.chartId === settings.id);
const datasetVariables: string[] = [];
const mapObj = state.mapping.find((d) => d.modelVariable === variable);
if (mapObj) {
datasetVariables.push(mapObj.datasetVariable);
}
charts[variable] = createForecastChart(
{
data: result,
variables: [`${pyciemssMap.value[variable]}:pre`, pyciemssMap.value[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
data: resultSummary,
variables: [`${pyciemssMap.value[variable]}_mean:pre`, `${pyciemssMap.value[variable]}_mean`],
timeField: 'timepoint_id'
},
{
data: groundTruthData.value,
variables: datasetVariables,
timeField: datasetTimeField as string,
groupField: 'sample_id'
},
{
title: variable,
width: chartSize.value.width,
height: chartSize.value.height,
legend: true,
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: modelVarUnits.value[variable] || '',
colorscheme: ['#AAB3C6', '#1B8073']
}
charts[variable] = applyForecastChartAnnotations(
createForecastChart(
{
data: result,
variables: [`${pyciemssMap.value[variable]}:pre`, pyciemssMap.value[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
data: resultSummary,
variables: [`${pyciemssMap.value[variable]}_mean:pre`, `${pyciemssMap.value[variable]}_mean`],
timeField: 'timepoint_id'
},
{
data: groundTruthData.value,
variables: datasetVariables,
timeField: datasetTimeField as string,
groupField: 'sample_id'
},
{
title: variable,
width: chartSize.value.width,
height: chartSize.value.height,
legend: true,
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: modelVarUnits.value[variable] || '',
colorscheme: ['#AAB3C6', '#1B8073']
}
),
annotations
);
});
return charts;
Expand All @@ -689,7 +716,8 @@ const preparedDistributionCharts = computed(() => {
const labelBefore = 'Before calibration';
const labelAfter = 'After calibration';
const charts = {};
selectedParameters.value.forEach((param) => {
selectedParameterSettings.value.forEach((setting) => {
const param = setting.selectedVariables[0];
const fieldName = pyciemssMap.value[param];
const beforeFieldName = `${fieldName}:pre`;
const histogram = createHistogramChart(result, {
Expand Down Expand Up @@ -718,13 +746,15 @@ const preparedDistributionCharts = computed(() => {
});

const errorChartVariables = computed(() => {
if (!selectedErrorVariables.value.length) return [];
if (!selectedErrorVariableSettings.value.length) return [];
const getDatasetVariable = (modelVariable: string) =>
mapping.value.find((d) => d.modelVariable === modelVariable)?.datasetVariable;
const variables = selectedErrorVariables.value.map((variable) => ({
field: getDatasetVariable(variable) as string,
label: variable
}));
const variables = selectedErrorVariableSettings.value
.map((s) => s.selectedVariables[0])
.map((variable) => ({
field: getDatasetVariable(variable) as string,
label: variable
}));
return variables;
});

Expand Down
Loading