Skip to content

Commit

Permalink
Apply chart annotations to the forecast charts on calibrate operator (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jryu01 authored Sep 19, 2024
1 parent ba1725c commit c577865
Show file tree
Hide file tree
Showing 10 changed files with 474 additions and 324 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const props = defineProps<{
* @param setting ChartSetting
* @param query llm query to generate annotation
*/
generateAnnotation?: (setting: ChartSetting, query: string) => Promise<ChartAnnotation>;
generateAnnotation?: (setting: ChartSetting, query: string) => Promise<ChartAnnotation | null>;
}>();
const emit = defineEmits(['close', 'update:settings', 'delete-annotation', 'create-annotation']);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,31 +217,35 @@
<section ref="outputPanel" v-if="modelConfig && csvAsset">
<h5>Parameters</h5>
<br />
<template v-for="param of selectedParameters" :key="param">
<template v-for="setting of selectedParameterSettings" :key="setting.id">
<vega-chart
expandable
:are-embed-actions-visible="true"
:visualization-spec="preparedDistributionCharts[param].histogram"
:visualization-spec="preparedDistributionCharts[setting.selectedVariables[0]].histogram"
>
<template v-slot:footer>
<table class="distribution-table">
<thead>
<tr>
<th scope="col"></th>
<th scope="col">{{ preparedDistributionCharts[param].stat.header[0] }}</th>
<th scope="col">{{ preparedDistributionCharts[param].stat.header[1] }}</th>
<th scope="col">
{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.header[0] }}
</th>
<th scope="col">
{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.header[1] }}
</th>
</tr>
</thead>
<tbody>
<tr>
<th scope="row">Mean</th>
<td>{{ preparedDistributionCharts[param].stat.mean[0] }}</td>
<td>{{ preparedDistributionCharts[param].stat.mean[1] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.mean[0] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.mean[1] }}</td>
</tr>
<tr>
<th scope="row">Variance</th>
<td>{{ preparedDistributionCharts[param].stat.variance[0] }}</td>
<td>{{ preparedDistributionCharts[param].stat.variance[1] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.variance[0] }}</td>
<td>{{ preparedDistributionCharts[setting.selectedVariables[0]].stat.variance[1] }}</td>
</tr>
</tbody>
</table>
Expand All @@ -250,12 +254,16 @@
</template>
<h5>Variables</h5>
<br />
<template v-for="variable of selectedVariables" :key="variable">
<vega-chart expandable :are-embed-actions-visible="true" :visualization-spec="preparedCharts[variable]" />
<template v-for="setting of selectedVariableSettings" :key="setting.id">
<vega-chart
expandable
:are-embed-actions-visible="true"
:visualization-spec="preparedCharts[setting.selectedVariables[0]]"
/>
</template>
<h5>Errors</h5>
<vega-chart
v-if="errorData.length > 0 && selectedErrorVariables.length > 0"
v-if="errorData.length > 0 && selectedErrorVariableSettings.length > 0"
:expandable="onExpandErrorChart"
:are-embed-actions-visible="true"
:visualization-spec="errorChart"
Expand Down Expand Up @@ -290,8 +298,7 @@
"
:active-settings="activeChartSettings"
:generate-annotation="generateAnnotation"
@create-annotation="addChartAnnotation"
@delete-annotation="removeChartAnnotation"
@delete-annotation="deleteAnnotation"
@close="activeChartSettings = null"
/>
</template>
Expand All @@ -308,7 +315,10 @@
@remove="removeChartSetting"
/>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedParameters }"
:chart-config="{
selectedRun: 'fixme',
selectedVariable: selectedParameterSettings.map((s) => s.selectedVariables[0])
}"
:multi-select="true"
:show-remove-button="false"
:variables="Object.keys(pyciemssMap).filter((c) => modelPartTypesMap[c] === 'parameter')"
Expand All @@ -326,7 +336,10 @@
@remove="removeChartSetting"
/>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedVariables }"
:chart-config="{
selectedRun: 'fixme',
selectedVariable: selectedVariableSettings.map((s) => s.selectedVariables[0])
}"
:multi-select="true"
:show-remove-button="false"
:variables="
Expand All @@ -346,7 +359,10 @@
@remove="removeChartSetting"
/>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedErrorVariables }"
:chart-config="{
selectedRun: 'fixme',
selectedVariable: selectedErrorVariableSettings.map((s) => s.selectedVariables[0])
}"
:multi-select="true"
:show-remove-button="false"
:variables="Object.keys(pyciemssMap).filter((c) => mapping.find((d) => d.modelVariable === c))"
Expand All @@ -371,7 +387,6 @@
<script setup lang="ts">
import _ from 'lodash';
import * as vega from 'vega';
import { v4 as uuidv4 } from 'uuid';
import { csvParse, autoType, mean, variance } from 'd3';
import { computed, onMounted, ref, shallowRef, watch } from 'vue';
import Button from 'primevue/button';
Expand All @@ -380,7 +395,14 @@ import Dropdown from 'primevue/dropdown';
import Column from 'primevue/column';
import TeraInputNumber from '@/components/widgets/tera-input-number.vue';
import { CalibrateMap, setupDatasetInput, setupModelInput } from '@/services/calibrate-workflow';
import { removeChartSettingById, updateChartSettingsBySelectedVariables } from '@/services/chart-settings';
import {
deleteAnnotation,
fetchAnnotations,
generateForecastChartAnnotation,
removeChartSettingById,
saveAnnotation,
updateChartSettingsBySelectedVariables
} from '@/services/chart-settings';
import TeraDrilldown from '@/components/drilldown/tera-drilldown.vue';
import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.vue';
import TeraProgressSpinner from '@/components/widgets/tera-progress-spinner.vue';
Expand Down Expand Up @@ -421,6 +443,7 @@ import {
createForecastChart,
createHistogramChart,
createErrorChart,
applyForecastChartAnnotations,
createInterventionChartMarkers
} from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
Expand All @@ -429,6 +452,7 @@ import TeraInputText from '@/components/widgets/tera-input-text.vue';
import { displayNumber } from '@/utils/number';
import TeraPyciemssCancelButton from '@/components/pyciemss/tera-pyciemss-cancel-button.vue';
import TeraSaveAssetModal from '@/components/project/tera-save-asset-modal.vue';
import { useClientEvent } from '@/composables/useClientEvent';
import { getInterventionPolicyById } from '@/services/intervention-policy';
import TeraInterventionSummaryCard from '@/components/workflow/ops/simulate-ciemss/tera-intervention-summary-card.vue';
import type { CalibrationOperationStateCiemss } from './calibrate-operation';
Expand Down Expand Up @@ -583,50 +607,44 @@ const outputPanel = ref(null);
const chartSize = computed(() => drilldownChartSize(outputPanel.value));
const chartSettings = computed(() => props.node.state.chartSettings ?? []);
const selectedParameters = computed(() =>
chartSettings.value
.filter((setting) => setting.type === ChartSettingType.DISTRIBUTION_COMPARISON)
.map((setting) => setting.selectedVariables[0])
const selectedParameterSettings = computed(() =>
chartSettings.value.filter((setting) => setting.type === ChartSettingType.DISTRIBUTION_COMPARISON)
);
const selectedVariables = computed(() =>
chartSettings.value
.filter((setting) => setting.type === ChartSettingType.VARIABLE_COMPARISON)
.map((setting) => setting.selectedVariables[0])
const selectedVariableSettings = computed(() =>
chartSettings.value.filter((setting) => setting.type === ChartSettingType.VARIABLE_COMPARISON)
);
const selectedErrorVariables = computed(() =>
chartSettings.value
.filter((setting) => setting.type === ChartSettingType.ERROR_DISTRIBUTION)
.map((setting) => setting.selectedVariables[0])
const selectedErrorVariableSettings = computed(() =>
chartSettings.value.filter((setting) => setting.type === ChartSettingType.ERROR_DISTRIBUTION)
);
// --- Handle chart annotations
const chartAnnotations = ref<ChartAnnotation[]>([]);
const generateAnnotation = async (setting: ChartSetting, query: string) => {
// Generate fake annotation. The annotation generation logic for the specific chart setting should go here
// Different chart settings type may have different annotation generation logic
await new Promise((resolve) => {
setTimeout(resolve, 1000);
});
const annotation: ChartAnnotation = {
id: uuidv4(),
description: query,
nodeId: props.node.id,
outputId: '',
chartId: setting.id,
layerSpec: {},
llmGenerated: false,
metadata: {}
};
return annotation;
const updateChartAnnotations = async () => {
chartAnnotations.value = await fetchAnnotations(props.node.id);
};
const addChartAnnotation = (annotation: ChartAnnotation) => {
chartAnnotations.value.push(annotation);
};
const removeChartAnnotation = (annotationId: string) => {
const index = chartAnnotations.value.findIndex((annotation) => annotation.id === annotationId);
if (index !== -1) {
chartAnnotations.value.splice(index, 1);
}
onMounted(() => updateChartAnnotations());
useClientEvent([ClientEventType.ChartAnnotationCreate, ClientEventType.ChartAnnotationDelete], updateChartAnnotations);
const generateAnnotation = async (setting: ChartSetting, query: string) => {
// Note: Currently llm generated chart annotations are supported for the forecast chart only
if (!preparedChartInputs.value) return null;
const { reverseMap } = preparedChartInputs.value;
const variable = setting.selectedVariables[0];
const annotationLayerSpec = await generateForecastChartAnnotation(
query,
'timpoint_id',
[`${pyciemssMap.value[variable]}_mean:pre`, `${pyciemssMap.value[variable]}_mean`],
{
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: modelVarUnits.value[variable] || ''
}
);
const saved = await saveAnnotation(annotationLayerSpec, props.node.id, setting.id);
return saved;
};
// ---
const pyciemssMap = ref<Record<string, string>>({});
const preparedChartInputs = computed(() => {
Expand Down Expand Up @@ -669,40 +687,45 @@ const preparedCharts = computed(() => {
const datasetTimeField = state.mapping.find((d) => d.modelVariable === 'timestamp')?.datasetVariable;
const charts = {};
selectedVariables.value.forEach((variable) => {
selectedVariableSettings.value.forEach((settings) => {
const variable = settings.selectedVariables[0];
const annotations = chartAnnotations.value.filter((annotation) => annotation.chartId === settings.id);
const datasetVariables: string[] = [];
const mapObj = state.mapping.find((d) => d.modelVariable === variable);
if (mapObj) {
datasetVariables.push(mapObj.datasetVariable);
}
charts[variable] = createForecastChart(
{
data: result,
variables: [`${pyciemssMap.value[variable]}:pre`, pyciemssMap.value[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
data: resultSummary,
variables: [`${pyciemssMap.value[variable]}_mean:pre`, `${pyciemssMap.value[variable]}_mean`],
timeField: 'timepoint_id'
},
{
data: groundTruthData.value,
variables: datasetVariables,
timeField: datasetTimeField as string,
groupField: 'sample_id'
},
{
title: variable,
width: chartSize.value.width,
height: chartSize.value.height,
legend: true,
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: modelVarUnits.value[variable] || '',
colorscheme: ['#AAB3C6', '#1B8073']
}
charts[variable] = applyForecastChartAnnotations(
createForecastChart(
{
data: result,
variables: [`${pyciemssMap.value[variable]}:pre`, pyciemssMap.value[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
data: resultSummary,
variables: [`${pyciemssMap.value[variable]}_mean:pre`, `${pyciemssMap.value[variable]}_mean`],
timeField: 'timepoint_id'
},
{
data: groundTruthData.value,
variables: datasetVariables,
timeField: datasetTimeField as string,
groupField: 'sample_id'
},
{
title: variable,
width: chartSize.value.width,
height: chartSize.value.height,
legend: true,
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: modelVarUnits.value[variable] || '',
colorscheme: ['#AAB3C6', '#1B8073']
}
),
annotations
);
charts[variable].layer.push(...createInterventionChartMarkers(groupedInterventionOutputs.value[variable]));
Expand All @@ -716,7 +739,8 @@ const preparedDistributionCharts = computed(() => {
const labelBefore = 'Before calibration';
const labelAfter = 'After calibration';
const charts = {};
selectedParameters.value.forEach((param) => {
selectedParameterSettings.value.forEach((setting) => {
const param = setting.selectedVariables[0];
const fieldName = pyciemssMap.value[param];
const beforeFieldName = `${fieldName}:pre`;
const histogram = createHistogramChart(result, {
Expand Down Expand Up @@ -745,13 +769,15 @@ const preparedDistributionCharts = computed(() => {
});
const errorChartVariables = computed(() => {
if (!selectedErrorVariables.value.length) return [];
if (!selectedErrorVariableSettings.value.length) return [];
const getDatasetVariable = (modelVariable: string) =>
mapping.value.find((d) => d.modelVariable === modelVariable)?.datasetVariable;
const variables = selectedErrorVariables.value.map((variable) => ({
field: getDatasetVariable(variable) as string,
label: variable
}));
const variables = selectedErrorVariableSettings.value
.map((s) => s.selectedVariables[0])
.map((variable) => ({
field: getDatasetVariable(variable) as string,
label: variable
}));
return variables;
});
Expand Down
Loading

0 comments on commit c577865

Please sign in to comment.