Skip to content

Commit

Permalink
4991 task update gollm intervention policy schema to align with new i…
Browse files Browse the repository at this point in the history
…ntervention policy model (#5010)
  • Loading branch information
dgauldie authored Oct 2, 2024
1 parent 2e0127f commit 083541f
Show file tree
Hide file tree
Showing 16 changed files with 214 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { InterventionPolicy } from '@/types/Types';
import type { BaseState, Operation } from '@/types/workflow';
import { WorkflowOperationTypes } from '@/types/workflow';
import type { Operation, BaseState } from '@/types/workflow';
import { isEqual, omit } from 'lodash';

export interface InterventionPolicyState extends BaseState {
interventionPolicy: InterventionPolicy;
taskIds: string[];
}

export const InterventionPolicyOperation: Operation = {
Expand All @@ -24,7 +25,8 @@ export const InterventionPolicyOperation: Operation = {
interventionPolicy: {
modelId: '',
interventions: []
}
},
taskIds: []
};
return init;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
<template #content>
<section>
<nav class="inline-flex">
<!-- Disabled until Backend code is complete -->
<!-- <Button class="flex-1 mr-1" outlined severity="secondary" label="Extract from inputs" /> -->
<Button
class="flex-1 ml-1"
label="Create New"
:disabled="!model?.id"
@click="createNewInterventionPolicy"
class="flex-1 mr-1"
outlined
severity="secondary"
label="Extract from inputs"
icon="pi pi-sparkles"
:loading="isLoading"
:disabled="!props.node.inputs[0]?.value && !props.node.inputs[1]?.value"
@click="extractInterventionPolicyFromInputs"
/>
<Button class="ml-1" label="Create New" :disabled="!model?.id" @click="createNewInterventionPolicy" />
</nav>
<tera-input-text v-model="filterInterventionsText" placeholder="Filter" />
<ul v-if="!isFetchingPolicies">
Expand Down Expand Up @@ -161,7 +164,7 @@

<script setup lang="ts">
import _, { cloneDeep, groupBy, isEmpty, omit } from 'lodash';
import { computed, onMounted, ref, watch, nextTick, ComponentPublicInstance } from 'vue';
import { ComponentPublicInstance, computed, nextTick, onMounted, ref, watch } from 'vue';
import TeraDrilldown from '@/components/drilldown/tera-drilldown.vue';
import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.vue';
import { WorkflowNode } from '@/types/workflow';
Expand All @@ -170,17 +173,17 @@ import TeraColumnarPanel from '@/components/widgets/tera-columnar-panel.vue';
import Button from 'primevue/button';
import TeraInputText from '@/components/widgets/tera-input-text.vue';
import { getInterventionPoliciesForModel, getModel } from '@/services/model';
import { Intervention, InterventionPolicy, Model, AssetType } from '@/types/Types';
import { AssetType, Intervention, InterventionPolicy, Model, type TaskResponse } from '@/types/Types';
import { logger } from '@/utils/logger';
import TeraProgressSpinner from '@/components/widgets/tera-progress-spinner.vue';
import { useConfirm } from 'primevue/useconfirm';
import { getParameters, getStates } from '@/model-representation/service';
import TeraToggleableInput from '@/components/widgets/tera-toggleable-input.vue';
import {
getInterventionPolicyById,
updateInterventionPolicy,
blankIntervention,
flattenInterventionData
flattenInterventionData,
getInterventionPolicyById,
updateInterventionPolicy
} from '@/services/intervention-policy';
import Accordion from 'primevue/accordion';
import AccordionTab from 'primevue/accordiontab';
Expand All @@ -192,12 +195,13 @@ import { createInterventionChart } from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
import TeraSaveAssetModal from '@/components/project/tera-save-asset-modal.vue';
import { useProjects } from '@/composables/project';
import { interventionPolicyFromDocument } from '@/services/goLLM';
import TeraInterventionCard from './tera-intervention-card.vue';
import {
InterventionPolicyOperation,
InterventionPolicyState,
isInterventionPoliciesValuesEqual,
isInterventionPoliciesEqual
isInterventionPoliciesEqual,
isInterventionPoliciesValuesEqual
} from './intervention-policy-operation';
import TeraInterventionPolicyCard from './tera-intervention-policy-card.vue';
Expand Down Expand Up @@ -234,6 +238,7 @@ const isSidebarOpen = ref(true);
const filterInterventionsText = ref('');
const model = ref<Model | null>(null);
const isFetchingPolicies = ref(false);
const isLoading = ref(false);
const interventionsPolicyList = ref<InterventionPolicy[]>([]);
const interventionPoliciesFiltered = computed(() =>
interventionsPolicyList.value
Expand Down Expand Up @@ -269,6 +274,13 @@ const isSaveDisabled = computed(() => {
return hasSelectedPolicy && (isPolicyIdDifferent || arePoliciesEqual || !arePolicyValuesEqual);
});
const documentIds = computed(() =>
props.node.inputs
.filter((input) => input.type === 'documentId' && input.status === 'connected')
.map((input) => input.value?.[0]?.documentId)
.filter((id): id is string => id !== undefined)
);
const parameterOptions = computed(() => {
if (!model.value) return [];
return getParameters(model.value).map((parameter) => ({
Expand Down Expand Up @@ -473,6 +485,29 @@ const createNewInterventionPolicy = () => {
showSaveModal.value = true;
};
const extractInterventionPolicyFromInputs = async () => {
const state = cloneDeep(props.node.state);
if (!model.value?.id) {
return;
}
if (documentIds.value) {
const promiseList = [] as Promise<TaskResponse | null>[];
documentIds.value.forEach((documentId) => {
promiseList.push(
interventionPolicyFromDocument(documentId, model.value?.id as string, props.node.workflowId, props.node.id)
);
});
const responsesRaw = await Promise.all(promiseList);
responsesRaw.forEach((resp) => {
if (resp) {
state.taskIds.push(resp.id);
}
});
}
emit('update-state', state);
};
watch(
() => knobs.value,
async () => {
Expand All @@ -493,6 +528,20 @@ watch(
}
);
watch(
() => props.node.state.taskIds,
async (watchVal) => {
if (watchVal.length > 0) {
isLoading.value = true;
} else {
isLoading.value = false;
const modelId = props.node.inputs[0].value?.[0];
if (!modelId) return;
await fetchInterventionPolicies(modelId);
}
}
);
onMounted(() => {
if (props.node.active) {
selectedOutputId.value = props.node.active;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
</li>
</ul>
<tera-operator-placeholder :node="node" v-else />
<tera-progress-spinner is-centered :font-size="2" v-if="isLoading" />
<Button
:label="isModelInputConnected ? 'Open' : 'Attach a model'"
@click="emit('open-drilldown')"
Expand All @@ -17,21 +18,37 @@
</template>

<script setup lang="ts">
import { computed, watch } from 'vue';
import { computed, ref, watch } from 'vue';
import { WorkflowNode, WorkflowPortStatus } from '@/types/workflow';
import Button from 'primevue/button';
import TeraOperatorPlaceholder from '@/components/operator/tera-operator-placeholder.vue';
import _, { cloneDeep, groupBy } from 'lodash';
import { blankIntervention, flattenInterventionData } from '@/services/intervention-policy';
import { createInterventionChart } from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
import { useClientEvent } from '@/composables/useClientEvent';
import { type ClientEvent, ClientEventType, type TaskResponse, TaskStatus } from '@/types/Types';
import TeraProgressSpinner from '@/components/widgets/tera-progress-spinner.vue';
import { InterventionPolicyState } from './intervention-policy-operation';
const emit = defineEmits(['open-drilldown', 'update-state']);
const props = defineProps<{
node: WorkflowNode<InterventionPolicyState>;
}>();
const taskIds = ref<string[]>([]);
const interventionEventHandler = async (event: ClientEvent<TaskResponse>) => {
if (!taskIds.value.includes(event.data?.id)) return;
if ([TaskStatus.Success, TaskStatus.Cancelled, TaskStatus.Failed].includes(event.data.status)) {
taskIds.value = taskIds.value.filter((id) => id !== event.data.id);
}
};
useClientEvent(ClientEventType.TaskGollmInterventionsFromDocument, interventionEventHandler);
const isLoading = computed(() => taskIds.value.length > 0);
const modelInput = props.node.inputs.find((input) => input.type === 'modelId');
const isModelInputConnected = computed(() => modelInput?.status === WorkflowPortStatus.CONNECTED);
Expand Down Expand Up @@ -69,6 +86,24 @@ watch(
},
{ deep: true }
);
watch(
() => props.node.state.taskIds,
() => {
taskIds.value = props.node.state.taskIds ?? [];
}
);
watch(
() => isLoading.value,
() => {
if (!isLoading.value) {
const state = cloneDeep(props.node.state);
state.taskIds = [];
emit('update-state', state);
}
}
);
</script>

<style scoped>
Expand Down
17 changes: 17 additions & 0 deletions packages/client/hmi-client/src/services/goLLM.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ export async function modelCard(modelId: string, documentId?: string): Promise<v
}
}

export async function interventionPolicyFromDocument(
documentId: string,
modelId: string,
workflowId?: string,
nodeId?: string
): Promise<TaskResponse> {
const { data } = await API.get<TaskResponse>('/gollm/interventions-from-document', {
params: {
'model-id': modelId,
'document-id': documentId,
'workflow-id': workflowId,
'node-id': nodeId
}
});
return data;
}

export async function configureModelFromDocument(
documentId: string,
modelId: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ export const createNotificationEventHandlers = (notificationItems: Ref<Notificat
});
registerHandler<TaskResponse>(ClientEventType.TaskGollmConfigureModelFromDocument, (event, created) => {
created.supportCancel = true;
created.sourceName = 'Configure model';
created.sourceName = 'Model Configuration from Document';
created.assetId = event.data.additionalProperties.workflowId as string;
created.pageType = AssetType.Workflow;
created.nodeId = event.data.additionalProperties.nodeId as string;
Expand All @@ -172,7 +172,7 @@ export const createNotificationEventHandlers = (notificationItems: Ref<Notificat
});
registerHandler<TaskResponse>(ClientEventType.TaskGollmConfigureModelFromDataset, (event, created) => {
created.supportCancel = true;
created.sourceName = 'Configure model';
created.sourceName = 'Model Configuration from Dataset';
created.assetId = event.data.additionalProperties.workflowId as string;
created.pageType = AssetType.Workflow;
created.nodeId = event.data.additionalProperties.nodeId as string;
Expand All @@ -190,6 +190,16 @@ export const createNotificationEventHandlers = (notificationItems: Ref<Notificat
Object.assign(created, { context: workflow?.name || '' })
);
});
registerHandler<TaskResponse>(ClientEventType.TaskGollmInterventionsFromDocument, (event, created) => {
created.supportCancel = true;
created.sourceName = 'Intervention Policies from Document';
created.assetId = event.data.additionalProperties.workflowId as string;
created.pageType = AssetType.Workflow;
created.nodeId = event.data.additionalProperties.nodeId as string;
getWorkflow(created.assetId, created.projectId).then((workflow) =>
Object.assign(created, { context: workflow?.name || '' })
);
});
registerHandler<StatusUpdate<SimulationNotificationData>>(
ClientEventType.SimulationNotification,
(event, created) => {
Expand Down
1 change: 1 addition & 0 deletions packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,7 @@ export enum ClientEventType {
TaskGollmEnrichAmr = "TASK_GOLLM_ENRICH_AMR",
TaskGollmEquationsFromImage = "TASK_GOLLM_EQUATIONS_FROM_IMAGE",
TaskGollmGenerateSummary = "TASK_GOLLM_GENERATE_SUMMARY",
TaskGollmInterventionsFromDocument = "TASK_GOLLM_INTERVENTIONS_FROM_DOCUMENT",
TaskGollmModelCard = "TASK_GOLLM_MODEL_CARD",
TaskMiraAmrToMmt = "TASK_MIRA_AMR_TO_MMT",
TaskMiraGenerateModelLatex = "TASK_MIRA_GENERATE_MODEL_LATEX",
Expand Down
8 changes: 8 additions & 0 deletions packages/gollm/gollm_openai/prompts/equations_from_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,13 @@
Do not respond in full sentences; only create a JSON object that satisfies the JSON schema specified in the response format.
LaTeX equations need to conform to a standard form. Use the following guidelines to ensure that your LaTeX equations are correctly formatted:
--- STYLE GUIDE ---
{style_guide}
--- STYLE GUIDE END ---
Answer:
"""
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
You are a helpful agent designed to find intervention policies for a given AMR model described in a research paper.
For context, intervention policies can include multiple interventions that include only static interventions or dynamic interventions.
Static interventions are applied at a specific point in time and permanently change the value of a specific parameter or state.
Dynamic interventions try to keep a specific parameter from going above or below a threshold value.
Dynamic interventions are applied when the value of a state crosses a threshold value.
Use the following AMR model JSON file as a reference:
Expand All @@ -26,9 +26,10 @@
3. `model_id` id a UUID. If the AMR model has an id, you can use it. Otherwise, you can set as the nil UUID "00000000-0000-0000-0000-000000000000".
4. For each intervention specified in the policy create an intervention object with the following rules.
a. Create a value for `name` from the user-provided text.
b. `appliedTo` should reference the id of the parameter or initial state of the AMR Model. If you cannot find an initial state or parameter that matches the intervention, do not create an intervention object.
c. `type` should be either "state" or "parameter" depending on what the intervention is applied to.
d. create a list of either static or dynamic interventions, but not both.
d. create a list of either static or dynamic interventions, but not both. The other list should be empty.
i. `appliedTo` should reference the id of a parameter or state in the AMR Model. If you cannot find a state or parameter in the AMR model that matches, do not create an intervention object.
ii. `type` should be either "state" or "parameter" depending on what the intervention is applied to.
iii. For dynamic interventions, `parameter` should be the id of a state that the threshold is applied to.
Do not respond in full sentences; only create a JSON object that satisfies the JSON schema specified in the response format.
Expand Down
18 changes: 18 additions & 0 deletions packages/gollm/gollm_openai/prompts/latex_style_guide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
LATEXT_STYLE_GUIDE = """
1) Derivatives must be written in Leibniz notation
2) First-order derivative must be on the left of the equal sign
3) Use whitespace to indicate multiplication
a) "*" is optional but probably should be avoided
4) "(t)" is optional and probably should be avoided
5) Avoid superscripts and LaTeX superscripts "^", particularly to denote sub-variables
6) Subscripts using LaTeX "_" are permitted
a) Ensure that all characters used in the subscript are surrounded by a pair of curly brackets "{...}"
7) Avoid mathematical constants like pi or Euler's number
a) Replace them as floats with 3 decimal places of precision
8) Avoid parentheses
9) Avoid capital sigma and pi notations for summation and product
10) Avoid non-ASCII characters when possible
11) Avoid using homoglyphs
12) Avoid words or multi-character names for variables and names
a) Use camel case to express multi-word or multi-character names
"""
Loading

0 comments on commit 083541f

Please sign in to comment.