Skip to content

Commit

Permalink
ranking interventions first pass and better simulation output naming (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnyama authored Jan 10, 2025
1 parent 5204260 commit 8582a0b
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,14 @@ export enum CompareValue {

export const blankCriteriaOfInterest = {
name: 'Criteria of interest',
configurations: [],
selectedConfiguration: null,
variables: [],
selectedConfigurationId: null,
selectedVariable: null,
rank: RankOption.MINIMUM,
timepoint: TimepointOption.LAST
};
export interface CriteriaOfInterestCard {
name: string;
configurations: string[];
selectedConfiguration: string | null;
variables: string[];
selectedConfigurationId: string | null;
selectedVariable: string | null;
rank: RankOption;
timepoint: TimepointOption;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
label="All simulations are from the same model"
disabled
/> -->
<div class="mb-4" />
<tera-checkbox
class="mt-2 mb-4"
v-model="areSimulationsFromSameModel"
label="All simulations are from the same model"
disabled
/>
<template v-if="knobs.selectedCompareOption === CompareValue.IMPACT">
<label> Select simulation to use as a baseline (optional) </label>
<Dropdown
Expand All @@ -40,11 +45,9 @@
placeholder="Optional"
@change="generateChartData"
/>
<div class="mb-4" />
<label>Comparison tables</label>
<tera-checkbox v-model="isATESelected" label="Average treatment effect (ATE)" />
</template>
<div class="mb-4" />
<!-- Pascale asked me to omit this timepoint selector, but I'm keeping it here until we are certain it's not needed -->
<!--
<label class="mt-2">Timepoint column</label>
Expand All @@ -56,12 +59,14 @@
/>
<div class="mb-4" />
-->
<template v-if="knobs.selectedCompareOption === CompareValue.RANK">
<div class="flex flex-column gap-2" v-if="knobs.selectedCompareOption === CompareValue.RANK">
<label>Specify criteria of interest:</label>
<tera-criteria-of-interest-card
v-for="(card, i) in node.state.criteriaOfInterestCards"
:key="i"
:card="card"
:model-configurations="modelConfigurations"
:variables="commonHeaderNames"
@delete="deleteCriteria(i)"
@update="(e) => updateCriteria(e, i)"
/>
Expand All @@ -75,7 +80,7 @@
@click="addCriteria"
/>
</div>
</template>
</div>
</tera-drilldown-section>
</template>
</tera-slider-panel>
Expand All @@ -85,16 +90,22 @@
<div ref="outputPanel">
<Accordion multiple :active-index="activeIndices">
<AccordionTab header="Summary"> </AccordionTab>
<AccordionTab header="Variables">
<template v-for="setting of selectedVariableSettings" :key="setting.id">
<vega-chart
:visualization-spec="variableCharts[setting.id]"
:are-embed-actions-visible="false"
expandable
/>
</template>
</AccordionTab>
<AccordionTab header="Comparison table"> </AccordionTab>
<template v-if="knobs.selectedCompareOption === CompareValue.IMPACT">
<AccordionTab header="Variables">
<template v-for="setting of selectedVariableSettings" :key="setting.id">
<vega-chart
:visualization-spec="variableCharts[setting.id]"
:are-embed-actions-visible="false"
expandable
/>
</template>
</AccordionTab>
<AccordionTab header="Comparison table"> </AccordionTab>
</template>
<template v-else>
<AccordionTab header="Ranking results"> </AccordionTab>
<AccordionTab header="Ranking criteria"> </AccordionTab>
</template>
</Accordion>
</div>
</tera-drilldown-section>
Expand Down Expand Up @@ -175,8 +186,10 @@ import Button from 'primevue/button';
import Accordion from 'primevue/accordion';
import AccordionTab from 'primevue/accordiontab';
import Dropdown from 'primevue/dropdown';
import { Dataset } from '@/types/Types';
import { Dataset, InterventionPolicy, ModelConfiguration } from '@/types/Types';
import { getDataset, getRawContent } from '@/services/dataset';
import { getInterventionPolicyById } from '@/services/intervention-policy';
import { getModelConfigurationById } from '@/services/model-configurations';
import TeraCheckbox from '@/components/widgets/tera-checkbox.vue';
import RadioButton from 'primevue/radiobutton';
import { isEmpty, cloneDeep } from 'lodash';
Expand Down Expand Up @@ -205,10 +218,12 @@ const emit = defineEmits(['update-state', 'update-status', 'close']);
const compareOptions: { label: string; value: CompareValue }[] = [
{ label: 'Compare the impact of interventions', value: CompareValue.IMPACT },
{ label: 'Rank interventions based on multiple charts', value: CompareValue.RANK }
{ label: 'Rank interventions based on multiple criteria', value: CompareValue.RANK }
];
const datasets = ref<Dataset[]>([]);
const modelConfigurations = ref<ModelConfiguration[]>([]);
const interventionPolicies = ref<InterventionPolicy[]>([]);
const commonHeaderNames = ref<string[]>([]);
const timepointHeaderName = ref<string | null>(null);
Expand All @@ -224,6 +239,7 @@ const isOutputSettingsOpen = ref(true);
const activeIndices = ref([0, 1, 2]);
const isFetchingDatasets = ref(false);
const areSimulationsFromSameModel = ref(true);
const isATESelected = ref(false);
const onRun = () => {
Expand Down Expand Up @@ -290,22 +306,43 @@ const initialize = async () => {
const state = cloneDeep(props.node.state);
knobs.value = Object.assign(knobs.value, state);
const interventionPolicyIds: string[] = [];
const modelConfigurationIds: string[] = [];
const inputs = props.node.inputs;
const datasetInputs = inputs.filter(
(input) => input.type === 'datasetId' && input.status === WorkflowPortStatus.CONNECTED
);
const promises = datasetInputs.map((input) => getDataset(input.value![0]));
const datasetPromises = datasetInputs.map((input) => getDataset(input.value![0]));
isFetchingDatasets.value = true;
await Promise.all(promises).then((ds) => {
const filteredDatasets: Dataset[] = ds.filter((dataset) => dataset !== null);
datasets.value.push(...filteredDatasets);
await Promise.all(datasetPromises).then((ds) => {
ds.forEach((dataset) => {
if (!dataset) return;
datasets.value.push(dataset);
const modelConfigurationId = dataset.metadata?.simulationAttributes?.modelConfigurationId;
const interventionPolicyId = dataset.metadata?.simulationAttributes?.interventionPolicyId;
if (modelConfigurationId) modelConfigurationIds.push(modelConfigurationId);
if (interventionPolicyId) interventionPolicyIds.push(interventionPolicyId);
});
});
isFetchingDatasets.value = false;
if (!knobs.value.selectedDataset) knobs.value.selectedDataset = datasets.value[0]?.id ?? null;
generateChartData();
if (isEmpty(modelConfigurationIds)) return;
const modelConfigurationPromises = modelConfigurationIds.map((id) => getModelConfigurationById(id));
await Promise.all(modelConfigurationPromises).then((configs) => {
modelConfigurations.value = configs.filter((config) => config !== null);
});
if (isEmpty(interventionPolicyIds)) return;
const interventionPolicyPromises = interventionPolicyIds.map((id) => getInterventionPolicyById(id));
await Promise.all(interventionPolicyPromises).then((policies) => {
interventionPolicies.value = policies.filter((policy) => policy !== null);
});
};
// Following two funcs are util like
Expand Down Expand Up @@ -335,7 +372,6 @@ function findDuplicates(strings: string[]): string[] {
async function generateChartData() {
if (datasets.value.length <= 1) return;
console.log(datasets);
const rawContents = await Promise.all(
datasets.value.map((dataset) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
/>
<Button class="ml-auto" text icon="pi pi-trash" @click="emit('delete')" />
</header>

<div>
For configuration
<Dropdown
:options="card.configurations"
:model-value="card.selectedConfiguration"
@update:model-value="emit('update', { selectedConfiguration: $event })"
placeholder="Select..."
class="madlib-dropdown"
:options="modelConfigurations"
:model-value="card.selectedConfigurationId"
@update:model-value="emit('update', { selectedConfigurationId: $event })"
option-label="name"
option-value="id"
/>
rank interventions based on the
<Dropdown
Expand All @@ -26,23 +26,20 @@
option-label="label"
option-value="value"
@update:model-value="emit('update', { rank: $event })"
class="madlib-dropdown"
/>
value of
<Dropdown
:options="card.variables"
placeholder="Select..."
:options="variables"
:model-value="card.selectedVariable"
@update:model-value="emit('update', { selectedVariable: $event })"
placeholder="Select..."
class="madlib-dropdown"
/>
at
<Dropdown
:options="timepointOptions"
:model-value="card.timepoint"
option-label="label"
option-value="value"
class="madlib-dropdown"
@update:model-value="emit('update', { timepoint: $event })"
/>
timepoint.
Expand All @@ -54,6 +51,7 @@
import TeraToggleableInput from '@/components/widgets/tera-toggleable-input.vue';
import Button from 'primevue/button';
import Dropdown from 'primevue/dropdown';
import { ModelConfiguration } from '@/types/Types';
import { CriteriaOfInterestCard, RankOption, TimepointOption } from './compare-datasets-operation';
const timepointOptions = [
Expand All @@ -69,6 +67,8 @@ const rankOptions = [
const emit = defineEmits(['delete', 'update']);
defineProps<{
card: CriteriaOfInterestCard;
modelConfigurations: ModelConfiguration[];
variables: string[];
}>();
</script>

Expand All @@ -81,9 +81,9 @@ defineProps<{
gap: var(--gap-2);
display: flex;
flex-direction: column;
margin-bottom: var(--gap-1);
}
.madlib-dropdown {
.p-dropdown {
height: 2rem;
margin-bottom: var(--gap-1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,23 @@ Provide a summary in 100 words or less.
const summaryResponse = await createLLMSummary(prompt);
const datasetName = `Forecast run ${runId}`;
const datasetName = nodeOutputLabel(props.node, interventionPolicy.value?.name ?? 'no intervention');
const projectId = useProjects().activeProject.value?.id ?? '';
const datasetResult = await createDatasetFromSimulationResult(projectId, runId, datasetName, false);
const datasetResult = await createDatasetFromSimulationResult(
projectId,
runId,
datasetName,
false,
modelConfiguration.value?.id,
interventionPolicy.value?.id
);
if (!datasetResult) {
logger.error('Error creating dataset from simulation result.');
return;
}
emit('append-output', {
type: SimulateCiemssOperation.outputs[0].type,
label: nodeOutputLabel(props.node, 'Dataset'),
label: datasetName,
value: [datasetResult.id],
state: {
currentTimespan: state.currentTimespan,
Expand Down
13 changes: 7 additions & 6 deletions packages/client/hmi-client/src/services/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ async function downloadRawFile(datasetId: string, filename: string, limit: numbe

if (!promise) {
const URL = `/datasets/${datasetId}/download-csv?filename=${filename}&limit=${limit}`;
console.log('URL', URL);
promise = API.get(URL)
.then((response) => response?.data ?? null)
.catch((error) => {
Expand Down Expand Up @@ -252,13 +251,15 @@ async function createDatasetFromSimulationResult(
projectId: string,
simulationId: string,
datasetName: string | null,
addToProject?: boolean
addToProject: boolean = true,
modelConfigurationId?: string,
interventionPolicyId?: string
): Promise<Dataset | null> {
if (addToProject === undefined) addToProject = true;
try {
const response: AxiosResponse<Dataset> = await API.post(
`/simulations/${simulationId}/create-result-as-dataset/${projectId}?dataset-name=${datasetName}&add-to-project=${addToProject}`
);
let URL = `/simulations/${simulationId}/create-result-as-dataset/${projectId}?dataset-name=${datasetName}&add-to-project=${addToProject}`;
if (modelConfigurationId) URL += `&model-configuration-id=${modelConfigurationId}`;
if (interventionPolicyId) URL += `&intervention-policy-id=${interventionPolicyId}`;
const response: AxiosResponse<Dataset> = await API.post(URL);
return response.data as Dataset;
} catch (error) {
logger.error(`/simulations/{id}/create-result-as-dataset/{projectId} not responding: ${error}`, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ public ResponseEntity<Dataset> createFromSimulationResult(
@PathVariable("id") final UUID id,
@PathVariable("project-id") final UUID projectId,
@RequestParam("dataset-name") final String datasetName,
@RequestParam("add-to-project") final Boolean addToProject
@RequestParam("add-to-project") final Boolean addToProject,
@RequestParam(value = "model-configuration-id", required = false) final UUID modelConfigurationId,
@RequestParam(value = "intervention-policy-id", required = false) final UUID interventionPolicyId
) {
final Schema.Permission permission = projectService.checkPermissionCanWrite(
currentUserService.get().getId(),
Expand All @@ -364,6 +366,8 @@ public ResponseEntity<Dataset> createFromSimulationResult(
datasetName,
projectId,
addToProject,
modelConfigurationId,
interventionPolicyId,
permission
);
return ResponseEntity.status(HttpStatus.CREATED).body(dataset);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package software.uncharted.terarium.hmiserver.service.data;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.micrometer.observation.annotation.Observed;
import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -93,6 +93,8 @@ public Dataset createDatasetFromSimulation(
final String datasetName,
final UUID projectId,
final boolean addToProject,
final UUID modelConfigurationId,
final UUID interventionPolicyId,
final Schema.Permission permission
) {
try {
Expand All @@ -109,11 +111,25 @@ public Dataset createDatasetFromSimulation(
Dataset dataset = datasetService.createAsset(new Dataset(), projectId, permission);
dataset.setName(datasetName);
dataset.setDescription(sim.get().getDescription());
dataset.setMetadata(objectMapper.convertValue(Map.of("simulationId", simId.toString()), JsonNode.class));
dataset.setFileNames(sim.get().getResultFiles());
dataset.setDataSourceDate(sim.get().getCompletedTime());
dataset.setColumns(new ArrayList<>());

// Set the metadata
ObjectNode metadata = objectMapper.createObjectNode();
ObjectNode simulationAttributes = objectMapper.createObjectNode();

metadata.put("simulationId", simId.toString());
if (modelConfigurationId != null) {
simulationAttributes.put("modelConfigurationId", modelConfigurationId.toString());
}
if (interventionPolicyId != null) {
simulationAttributes.put("interventionPolicyId", interventionPolicyId.toString());
}
metadata.set("simulationAttributes", simulationAttributes);

dataset.setMetadata(metadata);

// Attach the user to the dataset
if (sim.get().getUserId() != null) {
dataset.setUserId(sim.get().getUserId());
Expand Down

0 comments on commit 8582a0b

Please sign in to comment.