Skip to content

Commit

Permalink
Calibrate model config (#4266)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwdchang authored Jul 25, 2024
1 parent 8d5d282 commit 1e00dbf
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export const CalibrationOperationCiemss: Operation = {
isOptional: true
}
],
outputs: [{ type: 'simulationId' }],
outputs: [{ type: 'modelConfigId' }],
isRunnable: true,

action: async () => {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ import {
parsePyCiemssMap,
DataArray
} from '@/services/models/simulation-service';
import { getModelConfigurationById, createModelConfiguration } from '@/services/model-configurations';
import { getModelByModelConfigurationId } from '@/services/model';
import { setupDatasetInput } from '@/services/calibrate-workflow';
import { nodeMetadata, nodeOutputLabel } from '@/components/workflow/util';
import { logger } from '@/utils/logger';
import { Poller, PollerState } from '@/api/api';
import type { WorkflowNode } from '@/types/workflow';
import type { CsvAsset, SimulationRequest, Model } from '@/types/Types';
import type { CsvAsset, SimulationRequest, Model, ModelConfiguration } from '@/types/Types';
import { createLLMSummary } from '@/services/summary-service';
import { createForecastChart } from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
Expand Down Expand Up @@ -273,11 +274,25 @@ watch(
`;
const summaryResponse = await createLLMSummary(prompt);
const portLabel = props.node.inputs[0].label;
const baseConfig = await getModelConfigurationById(modelConfigId.value as string);
const calibratedModelConfig: ModelConfiguration = {
name: `Calibrated: ${baseConfig.name}`,
description: `Calibrated: ${baseConfig.description}`,
simulationId: state.calibrationId,
modelId: baseConfig.modelId,
observableSemanticList: [],
parameterSemanticList: [],
initialSemanticList: []
};
const modelConfigResponse = await createModelConfiguration(calibratedModelConfig);
// const portLabel = props.node.inputs[0].label;
emit('append-output', {
type: 'calibrateSimulationId',
label: nodeOutputLabel(props.node, `${portLabel} Result`),
value: [state.calibrationId],
type: 'modelConfigId',
label: nodeOutputLabel(props.node, `Calibration Result`),
value: [modelConfigResponse.id],
state: {
calibrationId: state.calibrationId,
forecastId: state.forecastId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ export const SimulateCiemssOperation: Operation = {
documentationUrl: DOCUMENTATION_URL,
inputs: [
{ type: 'modelConfigId', label: 'Model configuration', acceptMultiple: false },
{
type: 'calibrateSimulationId',
label: 'Calibration',
acceptMultiple: false,
isOptional: true
},
{
type: 'policyInterventionId',
label: 'Interventions',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@
<Dropdown id="5" v-model="method" :options="ciemssMethodOptions" @update:model-value="updateState" />
</div>
</div>
<!-- FIXME: show sampled values ???
<div v-if="inferredParameters">Using inferred parameters from calibration: {{ inferredParameters[0] }}</div>
-->
</div>
</tera-drilldown-section>
</section>
Expand Down Expand Up @@ -185,6 +187,7 @@ import { VAceEditor } from 'vue3-ace-editor';
import { VAceEditorInstance } from 'vue3-ace-editor/types';
import { createForecastChart } from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
import { getModelConfigurationById } from '@/services/model-configurations';
import { SimulateCiemssOperationState } from './simulate-ciemss-operation';
import TeraChartControl from '../../tera-chart-control.vue';
Expand All @@ -197,8 +200,7 @@ const modelVarUnits = ref<{ [key: string]: string }>({});
let editor: VAceEditorInstance['_editor'] | null;
const codeText = ref('');
const inferredParameters = computed(() => props.node.inputs[1].value);
const policyInterventionId = computed(() => props.node.inputs[2].value);
const policyInterventionId = computed(() => props.node.inputs[1].value);
const timespan = ref<TimeSpan>(props.node.state.currentTimespan);
const llmThoughts = ref<any[]>([]);
Expand Down Expand Up @@ -391,9 +393,11 @@ const makeForecastRequest = async () => {
engine: 'ciemss'
};
if (inferredParameters.value?.[0]) {
payload.extra.inferred_parameters = inferredParameters.value[0];
const modelConfig = await getModelConfigurationById(modelConfigId);
if (modelConfig.simulationId) {
payload.extra.inferred_parameters = modelConfig.simulationId;
}
if (policyInterventionId.value?.[0]) {
payload.policyInterventionId = policyInterventionId.value[0];
}
Expand Down

0 comments on commit 1e00dbf

Please sign in to comment.