Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4262 show loss chart in calibrate node #4269

Merged
merged 15 commits into from
Jul 25, 2024
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
<template>
<main>
<template v-if="!inProgressCalibrationId && runResult && csvAsset && runResultPre">
<template
v-if="
!inProgressCalibrationId && runResult && csvAsset && runResultPre && props.node.state.chartConfigs[0]?.length
"
>
<vega-chart
v-for="(_config, index) of props.node.state.chartConfigs"
:key="index"
:are-embed-actions-visible="false"
:visualization-spec="preparedCharts[index]"
/>
</template>
<template v-else>
<div ref="drilldownLossPlot" class="loss-chart" />
</template>
asylves1 marked this conversation as resolved.
Show resolved Hide resolved

<tera-progress-spinner v-if="inProgressCalibrationId" :font-size="2" is-centered style="height: 100%" />

Expand All @@ -21,26 +28,28 @@
<script setup lang="ts">
import _ from 'lodash';
import { csvParse, autoType } from 'd3';
import { computed, watch, ref, shallowRef } from 'vue';
import { computed, watch, ref, shallowRef, onMounted } from 'vue';
import Button from 'primevue/button';
import TeraOperatorPlaceholder from '@/components/operator/tera-operator-placeholder.vue';
import TeraProgressSpinner from '@/components/widgets/tera-progress-spinner.vue';
import {
getRunResultCSV,
pollAction,
subscribeToUpdateMessages,
unsubscribeToUpdateMessages,
makeForecastJobCiemss,
getSimulation,
parsePyCiemssMap,
DataArray
} from '@/services/models/simulation-service';
import { getModelConfigurationById, createModelConfiguration } from '@/services/model-configurations';
import { getModelByModelConfigurationId } from '@/services/model';
import { setupDatasetInput } from '@/services/calibrate-workflow';
import { renderLossGraph, setupDatasetInput } from '@/services/calibrate-workflow';
import { nodeMetadata, nodeOutputLabel } from '@/components/workflow/util';
import { logger } from '@/utils/logger';
import { Poller, PollerState } from '@/api/api';
import type { WorkflowNode } from '@/types/workflow';
import type { CsvAsset, SimulationRequest, Model, ModelConfiguration } from '@/types/Types';
import { ClientEvent, ClientEventType, CsvAsset, SimulationRequest, Model, ModelConfiguration } from '@/types/Types';
import { createLLMSummary } from '@/services/summary-service';
import { createForecastChart } from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
Expand All @@ -59,12 +68,65 @@ const runResult = ref<DataArray>([]);
const runResultPre = ref<DataArray>([]);
const runResultSummary = ref<DataArray>([]);
const runResultSummaryPre = ref<DataArray>([]);
const drilldownLossPlot = ref<HTMLElement>();

const csvAsset = shallowRef<CsvAsset | undefined>(undefined);

const areInputsFilled = computed(() => props.node.inputs[0].value && props.node.inputs[1].value);
const inProgressCalibrationId = computed(() => props.node.state.inProgressCalibrationId);

const selectedOutputId = ref<string>();
const selectedRunId = computed(() => props.node.outputs.find((o) => o.id === selectedOutputId.value)?.value?.[0]);
asylves1 marked this conversation as resolved.
Show resolved Hide resolved

let lossValues: { [key: string]: number }[] = [];

function drawLossGraph() {
if (drilldownLossPlot.value) {
renderLossGraph(drilldownLossPlot.value, lossValues, {
width: 200,
height: 120
});
}
}

async function updateLossChartWithSimulation() {
if (props.node.active) {
selectedOutputId.value = props.node.active;
const simulationObj = await getSimulation(selectedRunId.value);
if (simulationObj?.updates) {
lossValues = simulationObj?.updates.map((d, i) => ({
iter: i,
loss: d.data.loss
}));
drawLossGraph();
}
}
}

const messageHandler = (event: ClientEvent<any>) => {
lossValues.push({ iter: lossValues.length, loss: event.data.loss });
drawLossGraph();
};

watch(
() => props.node.state.inProgressCalibrationId,
(id) => {
if (id === '') {
asylves1 marked this conversation as resolved.
Show resolved Hide resolved
unsubscribeToUpdateMessages([id], ClientEventType.SimulationPyciemss, messageHandler);
} else {
subscribeToUpdateMessages([id], ClientEventType.SimulationPyciemss, messageHandler);
}
},
{ immediate: true }
);

watch(
() => props.node.state.inProgressPreForecastId,
() => updateLossChartWithSimulation()
);

onMounted(async () => updateLossChartWithSimulation());
asylves1 marked this conversation as resolved.
Show resolved Hide resolved

let pyciemssMap: Record<string, string> = {};

const preparedCharts = computed(() => {
Expand Down
Loading