Skip to content

Commit

Permalink
Ensemble Calibrate - Intermediate Results (#5401)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Szendrey authored Nov 6, 2024
1 parent e397bf3 commit 8b1929a
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export interface CalibrateEnsembleCiemssOperationState extends BaseState {
inProgressForecastId: string;
calibrationId: string;
forecastRunId: string;
currentProgress: number;
}

export const CalibrateEnsembleCiemssOperation: Operation = {
Expand Down Expand Up @@ -47,7 +48,8 @@ export const CalibrateEnsembleCiemssOperation: Operation = {
inProgressCalibrationId: '',
inProgressForecastId: '',
calibrationId: '',
forecastRunId: ''
forecastRunId: '',
currentProgress: 0
};
return init;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { createForecastChart, AUTOSIZE } from '@/services/charts';
import { getSimulation } from '@/services/models/simulation-service';

export async function getLossValuesFromSimulation(calibrationId: string) {
if (!calibrationId) return [];
const simulationObj = await getSimulation(calibrationId);
if (simulationObj?.updates) {
const lossValues = simulationObj?.updates
.sort((a, b) => a.data.progress - b.data.progress)
.map((d, i) => ({
iter: i,
loss: d.data.loss
}));
return lossValues;
}
return [];
}

export const updateLossChartSpec = (data: string | Record<string, any>[], size: { width: number; height: number }) =>
createForecastChart(
null,
{
data: Array.isArray(data) ? data : { name: data },
variables: ['loss'],
timeField: 'iter'
},
null,
{
title: '',
width: size.width,
height: 100,
xAxisTitle: 'Solver iterations',
yAxisTitle: 'Loss',
autosize: AUTOSIZE.FIT,
fitYDomain: true
}
);
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,18 @@
:options="outputs"
v-model:output="selectedOutputId"
@update:selection="onSelection"
:is-loading="showSpinner"
is-selectable
>
<!-- Loss chart -->
<div ref="lossChartContainer">
<vega-chart
v-if="!_.isEmpty(lossValues)"
expandable
ref="lossChartRef"
:are-embed-actions-visible="true"
:visualization-spec="lossChartSpec"
/>
</div>
<section v-if="!inProgressCalibrationId && !inProgressForecastId" ref="outputPanel">
<tera-simulate-chart
v-for="(cfg, index) of node.state.chartConfigs"
Expand All @@ -160,12 +169,15 @@
icon="pi pi-plus"
/>
</section>

<tera-progress-spinner
v-if="inProgressCalibrationId || inProgressForecastId"
:font-size="2"
is-centered
style="height: 100%"
/>
>
{{ node.state.currentProgress }}%
</tera-progress-spinner>
</tera-drilldown-preview>
</template>
</tera-drilldown>
Expand All @@ -178,8 +190,14 @@

<script setup lang="ts">
import _ from 'lodash';
import * as vega from 'vega';
import { ref, shallowRef, computed, watch, onMounted } from 'vue';
import { getRunResultCiemss, makeEnsembleCiemssCalibration } from '@/services/models/simulation-service';
import {
getRunResultCiemss,
makeEnsembleCiemssCalibration,
unsubscribeToUpdateMessages,
subscribeToUpdateMessages
} from '@/services/models/simulation-service';
import Button from 'primevue/button';
import TeraInputNumber from '@/components/widgets/tera-input-number.vue';
import AccordionTab from 'primevue/accordiontab';
Expand All @@ -193,22 +211,26 @@ import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.
import TeraDrilldownPreview from '@/components/drilldown/tera-drilldown-preview.vue';
import TeraSaveDatasetFromSimulation from '@/components/dataset/tera-save-dataset-from-simulation.vue';
import TeraPyciemssCancelButton from '@/components/pyciemss/tera-pyciemss-cancel-button.vue';
import { chartActionsProxy, drilldownChartSize, getTimespan, nodeMetadata } from '@/components/workflow/util';
import type {
CsvAsset,
EnsembleModelConfigs,
EnsembleCalibrationCiemssRequest,
ModelConfiguration,
Dataset
Dataset,
ClientEvent
} from '@/types/Types';
import { ClientEventType } from '@/types/Types';
import { RunResults } from '@/types/SimulateConfig';
import { WorkflowNode } from '@/types/workflow';
import { getDataset } from '@/services/dataset';
import { useDrilldownChartSize } from '@/composables/useDrilldownChartSize';
import VegaChart from '@/components/widgets/VegaChart.vue';
import {
CalibrateEnsembleCiemssOperationState,
EnsembleCalibrateExtraCiemss
} from './calibrate-ensemble-ciemss-operation';
import { updateLossChartSpec, getLossValuesFromSimulation } from './calibrate-ensemble-util';
const props = defineProps<{
node: WorkflowNode<CalibrateEnsembleCiemssOperationState>;
Expand Down Expand Up @@ -256,11 +278,17 @@ const inProgressForecastId = computed(() => props.node.state.inProgressForecastI
const datasetId = computed(() => props.node.inputs[0].value?.[0] as string | undefined);
const currentDatasetFileName = ref<string>();
const datasetColumnNames = ref<string[]>();
// Loss Chart:
const lossChartRef = ref<InstanceType<typeof VegaChart>>();
const lossChartSpec = ref();
const lossValues = ref<{ [key: string]: number }[]>([]);
const lossChartContainer = ref(null);
const lossChartSize = useDrilldownChartSize(lossChartContainer);
const LOSS_CHART_DATA_SOURCE = 'lossData';
// Model:
const listModelLabels = ref<string[]>([]);
const allModelConfigurations = ref<ModelConfiguration[]>([]);
// List of each observible + state for each model.
const allModelOptions = ref<any[][]>([]);
const allModelOptions = ref<any[][]>([]); // List of each observible + state for each model.
const newSolutionMappingKey = ref<string>('');
const runResults = ref<RunResults>({});
Expand Down Expand Up @@ -295,8 +323,19 @@ function addMapping() {
emit('update-state', state);
}
const messageHandler = (event: ClientEvent<any>) => {
if (!lossChartRef.value?.view) return;
const data = { iter: lossValues.value.length, loss: event.data.loss };
lossChartRef.value.view.change(LOSS_CHART_DATA_SOURCE, vega.changeset().insert(data)).resize().run();
lossValues.value.push(data);
};
const runEnsemble = async () => {
if (!datasetId.value || !currentDatasetFileName.value) return;
// Reset loss buffer
lossValues.value = [];
const datasetMapping: { [index: string]: string } = {};
datasetMapping[knobs.value.timestampColName] = 'timestamp';
// Each key used in the ensemble configs is a dataset column.
Expand Down Expand Up @@ -326,9 +365,9 @@ const runEnsemble = async () => {
const response = await makeEnsembleCiemssCalibration(calibratePayload, nodeMetadata(props.node));
if (response?.simulationId) {
const state = _.cloneDeep(props.node.state);
state.currentProgress = 0;
state.inProgressCalibrationId = response?.simulationId;
state.inProgressForecastId = '';
emit('update-state', state);
}
};
Expand Down Expand Up @@ -393,6 +432,8 @@ watch(
const state = props.node.state;
const output = await getRunResultCiemss(state.forecastRunId, 'result.csv');
runResults.value = output.runResults;
lossValues.value = await getLossValuesFromSimulation(props.node.state.calibrationId);
lossChartSpec.value = await updateLossChartSpec(lossValues.value, lossChartSize.value);
}
},
{ immediate: true }
Expand All @@ -409,6 +450,22 @@ watch(
},
{ deep: true }
);
watch(
[() => props.node.state.inProgressCalibrationId, lossChartSize],
([id, size]) => {
if (id === '') {
showSpinner.value = false;
lossChartSpec.value = updateLossChartSpec(lossValues.value, size);
unsubscribeToUpdateMessages([id], ClientEventType.SimulationPyciemss, messageHandler);
} else {
showSpinner.value = true;
lossChartSpec.value = updateLossChartSpec(LOSS_CHART_DATA_SOURCE, size);
subscribeToUpdateMessages([id], ClientEventType.SimulationPyciemss, messageHandler);
}
},
{ immediate: true }
);
</script>

<style scoped>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
<template>
<main>
<template v-if="!inProgressCalibrationId && !inProgressForecastId && runResults && csvAsset">
<vega-chart
v-if="!_.isEmpty(lossValues)"
:are-embed-actions-visible="false"
:visualization-spec="lossChartSpec"
/>

<tera-simulate-chart
v-for="(config, index) of props.node.state.chartConfigs"
:key="index"
Expand All @@ -22,7 +28,9 @@
:font-size="2"
is-centered
style="height: 100%"
/>
>
{{ node.state.currentProgress }}%
</tera-progress-spinner>

<Button v-if="areInputsFilled" label="Edit" @click="emit('open-drilldown')" severity="secondary" outlined />
<tera-operator-placeholder v-else :node="node">
Expand All @@ -32,7 +40,7 @@
</template>

<script setup lang="ts">
import { computed, ref, shallowRef, watch } from 'vue';
import { computed, ref, shallowRef, watch, onMounted } from 'vue';
import _ from 'lodash';
import Button from 'primevue/button';
import TeraOperatorPlaceholder from '@/components/operator/tera-operator-placeholder.vue';
Expand All @@ -47,14 +55,15 @@ import {
import { setupCsvAsset } from '@/services/calibrate-workflow';
import { chartActionsProxy, nodeMetadata, nodeOutputLabel } from '@/components/workflow/util';
import { logger } from '@/utils/logger';
import { Poller, PollerState } from '@/api/api';
import type { WorkflowNode } from '@/types/workflow';
import { WorkflowPortStatus } from '@/types/workflow';
import type { CsvAsset, EnsembleSimulationCiemssRequest, Dataset } from '@/types/Types';
import type { CsvAsset, EnsembleSimulationCiemssRequest, Dataset, Simulation } from '@/types/Types';
import type { RunResults } from '@/types/SimulateConfig';
import { getDataset } from '@/services/dataset';
import VegaChart from '@/components/widgets/VegaChart.vue';
import type { CalibrateEnsembleCiemssOperationState } from './calibrate-ensemble-ciemss-operation';
import { updateLossChartSpec, getLossValuesFromSimulation } from './calibrate-ensemble-util';
const props = defineProps<{
node: WorkflowNode<CalibrateEnsembleCiemssOperationState>;
Expand All @@ -67,6 +76,9 @@ const csvAsset = shallowRef<CsvAsset | undefined>(undefined);
const areInputsFilled = computed(() => props.node.inputs[0].value && props.node.inputs[1].value);
const inProgressCalibrationId = computed(() => props.node.state.inProgressCalibrationId);
const inProgressForecastId = computed(() => props.node.state.inProgressForecastId);
const lossValues = ref<{ [key: string]: number }[]>([]);
const lossChartSpec = ref();
const lossChartSize = { width: 180, height: 120 };
const chartProxy = chartActionsProxy(props.node, (state: CalibrateEnsembleCiemssOperationState) => {
emit('update-state', state);
Expand All @@ -77,11 +89,31 @@ const pollResult = async (runId: string) => {
poller
.setInterval(3000)
.setThreshold(300)
.setPollAction(async () => pollAction(runId));
.setPollAction(async () => pollAction(runId))
.setProgressAction((data: Simulation) => {
if (data?.updates?.length) {
lossValues.value = data?.updates
.sort((a, b) => a.data.progress - b.data.progress)
.map((d, i) => ({
iter: i,
loss: d.data.loss
}));
lossChartSpec.value = updateLossChartSpec(lossValues.value, lossChartSize);
}
if (runId === props.node.state.inProgressCalibrationId && data.updates.length > 0) {
const checkpoint = _.last(data.updates);
if (checkpoint) {
const state = _.cloneDeep(props.node.state);
state.currentProgress = +((100 * checkpoint.data.progress) / state.extra.numIterations).toFixed(2);
emit('update-state', state);
}
}
});
const pollerResults = await poller.start();
if (pollerResults.state === PollerState.Cancelled) {
const state = _.cloneDeep(props.node.state);
state.currentProgress = 0;
state.inProgressForecastId = '';
state.inProgressCalibrationId = '';
emit('update-state', state);
Expand All @@ -97,6 +129,12 @@ const pollResult = async (runId: string) => {
return pollerResults;
};
// Init loss chart
onMounted(async () => {
lossValues.value = await getLossValuesFromSimulation(props.node.state.calibrationId);
lossChartSpec.value = await updateLossChartSpec(lossValues.value, lossChartSize);
});
watch(
() => props.node.state.inProgressCalibrationId,
async (id) => {
Expand Down Expand Up @@ -143,6 +181,7 @@ watch(
const state = _.cloneDeep(props.node.state);
state.chartConfigs = [[]];
state.currentProgress = 0;
state.inProgressForecastId = '';
state.forecastRunId = id;
emit('update-state', state);
Expand Down

0 comments on commit 8b1929a

Please sign in to comment.