From 746905a899e1643eaf964ebe9c2b694dbe115459 Mon Sep 17 00:00:00 2001 From: dvince Date: Wed, 2 Oct 2024 14:36:05 -0400 Subject: [PATCH 1/3] [FEAT]: Add AMR enrichment to 'Enrich metadata with AI assistant' #4856 Adding a new CompoundTask which can run successive secondary tasks after the first finishes. Using this to run AMR enrichment and model card generation. --- .../widgets/tera-asset-enrichment.vue | 4 +- .../tera-model-from-equations-drilldown.vue | 4 +- .../client/hmi-client/src/services/goLLM.ts | 27 ++ .../controller/gollm/GoLLMController.java | 252 +++++++++++++----- .../hmiserver/models/task/CompoundTask.java | 28 ++ .../hmiserver/service/tasks/TaskService.java | 27 ++ 6 files changed, 266 insertions(+), 76 deletions(-) create mode 100644 packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/CompoundTask.java diff --git a/packages/client/hmi-client/src/components/widgets/tera-asset-enrichment.vue b/packages/client/hmi-client/src/components/widgets/tera-asset-enrichment.vue index 6932b14fd3..6ba5c81af3 100644 --- a/packages/client/hmi-client/src/components/widgets/tera-asset-enrichment.vue +++ b/packages/client/hmi-client/src/components/widgets/tera-asset-enrichment.vue @@ -51,7 +51,7 @@ import Button from 'primevue/button'; import RadioButton from 'primevue/radiobutton'; import { computed, ref, watch } from 'vue'; import { logger } from '@/utils/logger'; -import { modelCard } from '@/services/goLLM'; +import { enrichModelMetadata } from '@/services/goLLM'; import { useProjects } from '@/composables/project'; import TeraModal from '@/components/widgets/tera-modal.vue'; import { useClientEvent } from '@/composables/useClientEvent'; @@ -125,7 +125,7 @@ const confirm = async () => { const sendForEnrichment = async () => { // Build enrichment job ids list (profile asset, align model, etc...) if (props.assetId && props.assetType === AssetType.Model) { - await modelCard(props.assetId, selectedResourceId.value); + await enrichModelMetadata(props.assetId, selectedResourceId.value, true); } else if (props.assetType === AssetType.Dataset) { await profileDataset(props.assetId, selectedResourceId.value); } diff --git a/packages/client/hmi-client/src/components/workflow/ops/model-from-equations/tera-model-from-equations-drilldown.vue b/packages/client/hmi-client/src/components/workflow/ops/model-from-equations/tera-model-from-equations-drilldown.vue index 2524d25bec..28fba46d7b 100644 --- a/packages/client/hmi-client/src/components/workflow/ops/model-from-equations/tera-model-from-equations-drilldown.vue +++ b/packages/client/hmi-client/src/components/workflow/ops/model-from-equations/tera-model-from-equations-drilldown.vue @@ -168,7 +168,7 @@ import Textarea from 'primevue/textarea'; import TeraInputText from '@/components/widgets/tera-input-text.vue'; import TeraSaveAssetModal from '@/components/project/tera-save-asset-modal.vue'; import TeraModelDescription from '@/components/model/petrinet/tera-model-description.vue'; -import { modelCard } from '@/services/goLLM'; +import { enrichModelMetadata } from '@/services/goLLM'; import TeraSliderPanel from '@/components/widgets/tera-slider-panel.vue'; import TeraPdfEmbed from '@/components/widgets/tera-pdf-embed.vue'; @@ -384,7 +384,7 @@ function getEquationErrorLabel(equation) { // generates the model card and fetches the model when finished async function generateCard(modelId: string, docId: string) { isGeneratingCard.value = true; - await modelCard(modelId, docId); + await enrichModelMetadata(modelId, docId, true); isGeneratingCard.value = false; await fetchModel(); } diff --git a/packages/client/hmi-client/src/services/goLLM.ts b/packages/client/hmi-client/src/services/goLLM.ts index 836afe600f..077655ab5d 100644 --- a/packages/client/hmi-client/src/services/goLLM.ts +++ b/packages/client/hmi-client/src/services/goLLM.ts @@ -34,6 +34,33 @@ export async function modelCard(modelId: string, documentId?: string): Promise { + try { + const response = await API.get('/gollm/enrich-model-metadata', { + params: { + 'model-id': modelId, + 'document-id': documentId, + overwrite: overwrite + } + }); + + const taskId = response.data.id; + await handleTaskById(taskId, { + ondata(data, closeConnection) { + if (data?.status === TaskStatus.Failed) { + closeConnection(); + throw new FatalError('Task failed'); + } + if (data.status === TaskStatus.Success) { + closeConnection(); + } + } + }); + } catch (err) { + logger.error(err); + } +} + export async function configureModelFromDocument( documentId: string, modelId: string, diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java index 1b8548516e..cfa1deec99 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java @@ -42,6 +42,7 @@ import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset; import software.uncharted.terarium.hmiserver.models.dataservice.document.ExtractedDocumentPage; import software.uncharted.terarium.hmiserver.models.dataservice.model.Model; +import software.uncharted.terarium.hmiserver.models.task.CompoundTask; import software.uncharted.terarium.hmiserver.models.task.TaskRequest; import software.uncharted.terarium.hmiserver.models.task.TaskRequest.TaskType; import software.uncharted.terarium.hmiserver.models.task.TaskResponse; @@ -142,55 +143,30 @@ public ResponseEntity createModelCardTask( throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found")); } - final ModelCardResponseHandler.Input input = new ModelCardResponseHandler.Input(); - input.setAmr(model.get().serializeWithoutTerariumFields(null, new String[] { "gollmCard" })); - // Grab the document - final DocumentAsset document; if (documentId != null) { - final Optional documentOpt = documentAssetService.getAsset(documentId, permission); - if (documentOpt.isEmpty()) { - log.warn(String.format("Document %s not found", documentId)); - throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found")); - } - - document = documentOpt.get(); - - // make sure there is text in the document - if (document.getText() == null || document.getText().isEmpty()) { - log.warn(String.format("Document %s has no text to send", documentId)); - throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done")); - } - - // check for input length - if (document.getText().length() > ModelCardResponseHandler.MAX_TEXT_SIZE) { - log.warn(String.format("Document %s text too long for GoLLM model card task", documentId)); - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("document.text-length-exceeded")); - } - - input.setResearchPaper(document.getText()); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found")); } - // Create the task - final TaskRequest req = new TaskRequest(); - req.setType(TaskType.GOLLM); - req.setScript(ModelCardResponseHandler.NAME); - req.setUserId(currentUserService.get().getId()); + final Optional documentOpt = documentAssetService.getAsset(documentId, permission); + if (documentOpt.isEmpty()) { + log.warn(String.format("Document %s not found", documentId)); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found")); + } - try { - req.setInput(objectMapper.writeValueAsBytes(input)); - } catch (final Exception e) { - log.error("Unable to serialize input", e); - throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); + // make sure there is text in the document + if (documentOpt.get().getText() == null || documentOpt.get().getText().isEmpty()) { + log.warn(String.format("Document %s has no text to send", documentId)); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done")); } - req.setProjectId(projectId); + // check for input length + if (documentOpt.get().getText().length() > ModelCardResponseHandler.MAX_TEXT_SIZE) { + log.warn(String.format("Document %s text too long for GoLLM model card task", documentId)); + throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("document.text-length-exceeded")); + } - final ModelCardResponseHandler.Properties props = new ModelCardResponseHandler.Properties(); - props.setProjectId(projectId); - props.setDocumentId(documentId); - props.setModelId(modelId); - req.setAdditionalProperties(props); + final TaskRequest req = getModelCardTask(documentOpt.get(), model.get(), projectId); final TaskResponse resp; try { @@ -744,6 +720,101 @@ public ResponseEntity createGenerateResponseTask( return ResponseEntity.ok().body(resp); } + /** + * This endpoint will dispatch a few GoLLM tasks to enrich model metadata, + * including enriching the AMR and creating the model card + * @param modelId UUID of the model to enrich + * @param documentId UUID of the document to use for enrichment + * @param mode TaskMode to run the task in (is this ASYNC?) + * @param projectId UUID of the project to associate the task with for permissions + * @param overwrite boolean to determine if the model should be overwritten + * @return TaskResponse with the task ID + */ + @GetMapping("/enrich-model-metadata") + @Secured(Roles.USER) + @Operation(summary = "Dispatch a multiple GoLLM tasks to enrich model metadata") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "Dispatched successfully", + content = @Content( + mediaType = "application/json", + schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class) + ) + ), + @ApiResponse( + responseCode = "404", + description = "The provided model or document arguments are not found", + content = @Content + ), + @ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content) + } + ) + public ResponseEntity createEnrichModelMetadataTask( + @RequestParam(name = "model-id", required = true) final UUID modelId, + @RequestParam(name = "document-id", required = false) final UUID documentId, + @RequestParam(name = "mode", required = false, defaultValue = "ASYNC") final TaskMode mode, + @RequestParam(name = "project-id", required = false) final UUID projectId, + @RequestParam(name = "overwrite", required = false, defaultValue = "false") final boolean overwrite + ) { + final Schema.Permission permission = projectService.checkPermissionCanRead( + currentUserService.get().getId(), + projectId + ); + + // Grab the document + final Optional document = documentAssetService.getAsset(documentId, permission); + + // make sure there is text in the document. We don't need a document but if we do have one it can't be empty + if (document.isPresent() && (document.get().getText() == null || document.get().getText().isEmpty())) { + log.warn(String.format("Document %s has no extracted text", documentId)); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done")); + } + + // Grab the model + final Optional model = modelService.getAsset(modelId, permission); + if (model.isEmpty()) { + log.warn(String.format("Model %s not found", modelId)); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found")); + } + + final TaskRequest req; + + if (document.isPresent()) { + final TaskRequest enrichAmrRequest = getEnrichAMRTaskRequest( + document.orElse(null), + model.get(), + projectId, + overwrite + ); + final TaskRequest modelCardRequest = getModelCardTask(document.orElse(null), model.get(), projectId); + + req = new CompoundTask(enrichAmrRequest, modelCardRequest); + } else { + req = getModelCardTask(null, model.get(), projectId); + } + + final TaskResponse resp; + try { + resp = taskService.runTask(mode, req); + } catch (final JsonProcessingException e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing")); + } catch (final TimeoutException e) { + log.warn("Timeout while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout")); + } catch (final InterruptedException e) { + log.warn("Interrupted while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.gollm.interrupted")); + } catch (final ExecutionException e) { + log.error("Error while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure")); + } + + return ResponseEntity.ok().body(resp); + } + @GetMapping("/enrich-amr") @Secured(Roles.USER) @Operation(summary = "Dispatch a `GoLLM Enrich AMR` task") @@ -797,41 +868,11 @@ public ResponseEntity createEnrichAMRTask( throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found")); } - final EnrichAmrResponseHandler.Input input = new EnrichAmrResponseHandler.Input(); - input.setResearchPaper(document.get().getText()); // stripping the metadata from the model before its sent since it can cause // gollm to fail with massive inputs model.get().setMetadata(null); - try { - final String amr = objectMapper.writeValueAsString(model.get()); - input.setAmr(amr); - } catch (final JsonProcessingException e) { - log.error("Unable to serialize model card", e); - throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing")); - } - - // Create the task - final TaskRequest req = new TaskRequest(); - req.setType(TaskType.GOLLM); - req.setScript(EnrichAmrResponseHandler.NAME); - req.setUserId(currentUserService.get().getId()); - - try { - req.setInput(objectMapper.writeValueAsBytes(input)); - } catch (final Exception e) { - log.error("Unable to serialize input", e); - throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); - } - - req.setProjectId(projectId); - - final EnrichAmrResponseHandler.Properties props = new EnrichAmrResponseHandler.Properties(); - props.setProjectId(projectId); - props.setDocumentId(documentId); - props.setModelId(modelId); - props.setOverwrite(overwrite); - req.setAdditionalProperties(props); + TaskRequest req = getEnrichAMRTaskRequest(document.get(), model.get(), projectId, overwrite); final TaskResponse resp; try { @@ -1063,4 +1104,71 @@ public ResponseEntity cancelTask(@PathVariable("task-id") final UUID taskI taskService.cancelTask(taskId); return ResponseEntity.ok().build(); } + + private TaskRequest getModelCardTask(DocumentAsset document, Model model, UUID projectId) { + final ModelCardResponseHandler.Input input = new ModelCardResponseHandler.Input(); + input.setAmr(model.serializeWithoutTerariumFields(null, new String[] { "gollmCard" })); + + if (document != null) input.setResearchPaper(document.getText()); + + // Create the task + final TaskRequest req = new TaskRequest(); + req.setType(TaskType.GOLLM); + req.setScript(ModelCardResponseHandler.NAME); + req.setUserId(currentUserService.get().getId()); + + try { + req.setInput(objectMapper.writeValueAsBytes(input)); + } catch (final Exception e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); + } + + req.setProjectId(projectId); + + final ModelCardResponseHandler.Properties props = new ModelCardResponseHandler.Properties(); + props.setProjectId(projectId); + if (document != null) props.setDocumentId(document.getId()); + props.setModelId(model.getId()); + req.setAdditionalProperties(props); + + return req; + } + + private TaskRequest getEnrichAMRTaskRequest(DocumentAsset document, Model model, UUID projectId, Boolean overwrite) { + final EnrichAmrResponseHandler.Input input = new EnrichAmrResponseHandler.Input(); + if (document != null) input.setResearchPaper(document.getText()); + + try { + final String amr = objectMapper.writeValueAsString(model); + input.setAmr(amr); + } catch (final JsonProcessingException e) { + log.error("Unable to serialize model card", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing")); + } + + // Create the task + final TaskRequest req = new TaskRequest(); + req.setType(TaskType.GOLLM); + req.setScript(EnrichAmrResponseHandler.NAME); + req.setUserId(currentUserService.get().getId()); + + try { + req.setInput(objectMapper.writeValueAsBytes(input)); + } catch (final Exception e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); + } + + req.setProjectId(projectId); + + final EnrichAmrResponseHandler.Properties props = new EnrichAmrResponseHandler.Properties(); + props.setProjectId(projectId); + if (document != null) props.setDocumentId(document.getId()); + props.setModelId(model.getId()); + props.setOverwrite(overwrite); + req.setAdditionalProperties(props); + + return req; + } } diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/CompoundTask.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/CompoundTask.java new file mode 100644 index 0000000000..be561d2f36 --- /dev/null +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/CompoundTask.java @@ -0,0 +1,28 @@ +package software.uncharted.terarium.hmiserver.models.task; + +import java.util.List; +import lombok.Data; + +/** + * Represents a compound task that consists of a primary task and one or more secondary tasks. + */ +@Data +public class CompoundTask extends TaskRequest { + + /** + * Constructs a CompoundTask with a primary task and optional secondary tasks. + * + * @param primaryTask the primary task + * @param secondaryTasks the secondary tasks + */ + public CompoundTask(TaskRequest primaryTask, TaskRequest... secondaryTasks) { + this.primaryTask = primaryTask; + this.secondaryTasks = List.of(secondaryTasks); + } + + /** The primary task of the compound task. */ + private TaskRequest primaryTask; + + /** The list of secondary tasks of the compound task. */ + private List secondaryTasks; +} diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java index 981a6e3de2..8abcca21a1 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java @@ -38,6 +38,7 @@ import software.uncharted.terarium.hmiserver.models.ClientEventType; import software.uncharted.terarium.hmiserver.models.notification.NotificationEvent; import software.uncharted.terarium.hmiserver.models.notification.NotificationGroup; +import software.uncharted.terarium.hmiserver.models.task.CompoundTask; import software.uncharted.terarium.hmiserver.models.task.TaskFuture; import software.uncharted.terarium.hmiserver.models.task.TaskRequest; import software.uncharted.terarium.hmiserver.models.task.TaskResponse; @@ -635,6 +636,10 @@ public TaskResponse runTaskSync(final TaskRequest req) public TaskResponse runTask(final TaskMode mode, final TaskRequest req) throws JsonProcessingException, TimeoutException, InterruptedException, ExecutionException { + if (req instanceof CompoundTask) { + return runTask(mode, (CompoundTask) req); + } + if (mode == TaskMode.SYNC) { return runTaskSync(req); } else if (mode == TaskMode.ASYNC) { @@ -643,4 +648,26 @@ public TaskResponse runTask(final TaskMode mode, final TaskRequest req) throw new IllegalArgumentException("Invalid task mode: " + mode); } } + + /** + * Runs a compound task, executing the primary task synchronously and the secondary tasks + * in the specified mode (synchronous or asynchronous). + * + * @param mode The mode in which to run the secondary tasks (SYNC or ASYNC). + * @param req The compound task containing the primary and secondary tasks. + * @return The response of the primary task. + * @throws JsonProcessingException If there is an error processing JSON. + * @throws TimeoutException If the task times out. + * @throws InterruptedException If the task is interrupted. + * @throws ExecutionException If there is an error during task execution. + */ + public TaskResponse runTask(final TaskMode mode, final CompoundTask req) + throws JsonProcessingException, TimeoutException, InterruptedException, ExecutionException { + TaskResponse response = runTask(TaskMode.SYNC, req.getPrimaryTask()); + + for (final TaskRequest secondaryTask : req.getSecondaryTasks()) { + runTask(mode, secondaryTask); + } + return response; + } } From 361a2fd8d2b9ff3d3670069a6712fa65ec1c1636 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 18:40:19 +0000 Subject: [PATCH 2/3] chore: lint and format client codebase --- packages/client/hmi-client/src/services/goLLM.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/client/hmi-client/src/services/goLLM.ts b/packages/client/hmi-client/src/services/goLLM.ts index 077655ab5d..173ac1f2e5 100644 --- a/packages/client/hmi-client/src/services/goLLM.ts +++ b/packages/client/hmi-client/src/services/goLLM.ts @@ -40,7 +40,7 @@ export async function enrichModelMetadata(modelId: string, documentId: string, o params: { 'model-id': modelId, 'document-id': documentId, - overwrite: overwrite + overwrite } }); From a261dea3e84b28d6e1f6a53061995dff38a03056 Mon Sep 17 00:00:00 2001 From: Derek Vince Date: Wed, 2 Oct 2024 15:50:21 -0400 Subject: [PATCH 3/3] Update packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java Co-authored-by: Yohann Paris --- .../terarium/hmiserver/controller/gollm/GoLLMController.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java index 6285d82878..ac8893a192 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java @@ -732,7 +732,7 @@ public ResponseEntity createGenerateResponseTask( */ @GetMapping("/enrich-model-metadata") @Secured(Roles.USER) - @Operation(summary = "Dispatch a multiple GoLLM tasks to enrich model metadata") + @Operation(summary = "Dispatch multiple GoLLM tasks to enrich model metadata") @ApiResponses( value = { @ApiResponse(