Skip to content

Commit

Permalink
add extractionService for equations-to-model to use MIRA (#5721)
Browse files Browse the repository at this point in the history
Co-authored-by: Cole Blanchard <cblanchard@Coles-MacBook-Pro.local>
Co-authored-by: Cole Blanchard <33158416+blanchco@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 3, 2024
1 parent 789134d commit f695e95
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
<span class="flex align-items-center">Specify which equations to use for this model.</span>
<section class="white-space-nowrap min-w-min">
<Button class="mr-1" label="Reset" severity="secondary" outlined />
<Button

<SplitButton
label="Run"
@click="onRun"
:disabled="isDocumentLoading || isEmpty(includedEquations)"
:loading="isModelLoading"
:model="runItems"
:disabled="isDocumentLoading || isEmpty(includedEquations) || isModelLoading"
/>
</section>
</nav>
Expand Down Expand Up @@ -198,6 +198,7 @@ import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.
import TeraPdfEmbed from '@/components/widgets/tera-pdf-embed.vue';
import TeraTextEditor from '@/components/documents/tera-text-editor.vue';
import { logger } from '@/utils/logger';
import SplitButton from 'primevue/splitbutton';
import { ModelFromEquationsState, EquationBlock } from './model-from-equations-operation';
const emit = defineEmits(['close', 'update-state', 'append-output', 'update-output', 'select-output']);
Expand All @@ -207,6 +208,17 @@ const props = defineProps<{
const selectedOutputId = ref<string>('');
const runItems = [
{
label: 'SKEMA',
command: () => onRun('skema')
},
{
label: 'Mira',
command: () => onRun('mira')
}
];
const clonedState = ref<ModelFromEquationsState>({
equations: [],
text: '',
Expand Down Expand Up @@ -365,7 +377,7 @@ function onCheckBoxChange(equation) {
emit('update-state', state);
}
async function onRun() {
async function onRun(extractionService: 'mira' | 'skema' = 'skema') {
isOutputOpen.value = true;
isModelLoading.value = true;
const equationsText = clonedState.value.equations
Expand All @@ -382,7 +394,8 @@ async function onRun() {
equations: cleanedEquations,
documentId: document.value?.id,
workflowId: props.node.workflowId,
nodeId: props.node.id
nodeId: props.node.id,
extractionService
};
const modelId = await equationsToAMR(request);
// If there isn't a modelId returned at least show the cleaned equations
Expand Down Expand Up @@ -580,4 +593,17 @@ watch(
.p-panel:deep(.p-panel-footer) {
display: none;
}
:deep(.p-splitbutton .p-button:first-of-type) {
border-top-right-radius: 0;
border-bottom-right-radius: 0;
border-right: 0 none;
pointer-events: none;
}
:deep(.p-splitbutton .p-button:last-of-type) {
border-top-left-radius: 0;
border-bottom-left-radius: 0;
color: #fff;
}
</style>
1 change: 1 addition & 0 deletions packages/client/hmi-client/src/services/knowledge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export interface EquationsToAMRRequest {
documentId?: DocumentAsset['id'];
workflowId?: Workflow['id'];
nodeId?: WorkflowNode<any>['id'];
extractionService?: 'mira' | 'skema';
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { useProjects } from '@/composables/project';
import {
CloneProjectStatusUpdate,
ExtractionStatusUpdate,
ModelEnrichmentStatusUpdate,
NotificationItem,
NotificationItemStatus
} from '@/types/common';
Expand Down Expand Up @@ -205,22 +206,21 @@ export const createNotificationEventHandlers = (notificationItems: Ref<Notificat
created.typeDisplayName = `${snakeToCapitalSentence(event.data.data.simulationType)} (${event.data.data.simulationEngine.toLowerCase()})`;
}
);
registerHandler<TaskResponse>(ClientEventType.KnowledgeEnrichmentModel, (event, created) => {
registerHandler<ModelEnrichmentStatusUpdate>(ClientEventType.KnowledgeEnrichmentModel, (event, created) => {
created.sourceName = 'Model Enrichment';

// Check if the event data contains a workflowId and nodeId
if (event.data.additionalProperties.workflowId && event.data.additionalProperties.nodeId) {
created.assetId = event.data.additionalProperties.workflowId as string;
if (event.data.data?.workflowId && event.data.data?.nodeId) {
created.assetId = event.data.data.workflowId as string;
created.pageType = AssetType.Workflow;
created.nodeId = event.data.additionalProperties.nodeId as string;
created.nodeId = event.data.data.nodeId as string;
getWorkflow(created.assetId, created.projectId).then((workflow) =>
Object.assign(created, { context: workflow?.name || '' })
);
}

// We display the model page from where the enrichment model was triggered
else {
created.assetId = event.data.additionalProperties.modelId as string;
created.assetId = event.data.data.modelId as string;
created.pageType = AssetType.Model;
created.typeDisplayName = 'Model Enrichment';
getModel(created.assetId, created.projectId).then((model) =>
Expand Down
2 changes: 2 additions & 0 deletions packages/client/hmi-client/src/types/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ interface Comparison {
conclusion: string;
}

export type ModelEnrichmentStatusUpdate = StatusUpdate<{ modelId: string; workflowId: string; nodeId: string }>;

export type ExtractionStatusUpdate = StatusUpdate<{ documentId: string }>;
export type CloneProjectStatusUpdate = StatusUpdate<{ projectId: string }>;
export interface NotificationItem extends NotificationItemStatus, AssetRoute {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import software.uncharted.terarium.hmiserver.service.notification.NotificationGroupInstance;
import software.uncharted.terarium.hmiserver.service.notification.NotificationService;
import software.uncharted.terarium.hmiserver.service.tasks.EquationsCleanupResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.LatexToAMRResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService.TaskMode;
import software.uncharted.terarium.hmiserver.utils.ByteMultipartFile;
Expand Down Expand Up @@ -260,29 +261,54 @@ public ResponseEntity<UUID> equationsToModel(
}
}

// Create a request for SKEMA with the cleaned-up equations.
final JsonNode skemaRequest = mapper.createObjectNode().put("model", "petrinet").set("equations", equationsReq);

// Get an AMR from Skema Unified Service
// Get an AMR from Skema Unified Service or MIRA
final Model responseAMR;
try {
responseAMR = skemaUnifiedProxy.consolidatedEquationsToAMR(skemaRequest).getBody();
if (responseAMR == null) {
log.warn("Skema Unified Service did not return a valid AMR based on the provided equations");
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("skema.bad-equations"));
String extractionService = "mira";
if (req.get("extractionService") != null) {
extractionService = req.get("extractionService").asText();
}

if (extractionService.equals("mira")) {
final TaskRequest taskReq = new TaskRequest();
final String latex = req.get("equations").toString();
taskReq.setType(TaskType.MIRA);
try {
taskReq.setInput(latex.getBytes());
taskReq.setScript(LatexToAMRResponseHandler.NAME);
taskReq.setUserId(currentUserService.get().getId());
final TaskResponse taskResp = taskService.runTaskSync(taskReq);
final JsonNode taskResponseJSON = mapper.readValue(taskResp.getOutput(), JsonNode.class);
final String amrString = taskResponseJSON.get("response").asText();
ObjectNode objNode = (ObjectNode) mapper.readTree(amrString);

final JsonNode testNode = mapper.readValue(amrString, JsonNode.class);
responseAMR = mapper.convertValue(testNode, Model.class);
} catch (Exception e) {
log.error("failed to convert latex equations to AMR", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "failed to convert latex equations to AMR");
}
} else {
try {
// Create a request for SKEMA with the cleaned-up equations.
final JsonNode skemaRequest = mapper.createObjectNode().put("model", "petrinet").set("equations", equationsReq);
responseAMR = skemaUnifiedProxy.consolidatedEquationsToAMR(skemaRequest).getBody();
if (responseAMR == null) {
log.warn("Skema Unified Service did not return a valid AMR based on the provided equations");
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("skema.bad-equations"));
}
} catch (final FeignException e) {
log.error(
"An exception occurred while Skema Unified Service was trying to produce an AMR based on the provided equations",
e
);
throw handleSkemaFeignException(e);
} catch (final Exception e) {
log.error(
"An unhandled error occurred while Skema Unified Service was trying to produce an AMR based on the provided equations",
e
);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("skema.internal-error"));
}
} catch (final FeignException e) {
log.error(
"An exception occurred while Skema Unified Service was trying to produce an AMR based on the provided equations",
e
);
throw handleSkemaFeignException(e);
} catch (final Exception e) {
log.error(
"An unhandled error occurred while Skema Unified Service was trying to produce an AMR based on the provided equations",
e
);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("skema.internal-error"));
}

// We only handle Petri Net models
Expand Down

0 comments on commit f695e95

Please sign in to comment.