Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: Add AMR enrichment to 'Enrich metadata with AI assistant' #5019

Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import Button from 'primevue/button';
import RadioButton from 'primevue/radiobutton';
import { computed, ref, watch } from 'vue';
import { logger } from '@/utils/logger';
import { modelCard } from '@/services/goLLM';
import { enrichModelMetadata } from '@/services/goLLM';
import { useProjects } from '@/composables/project';
import TeraModal from '@/components/widgets/tera-modal.vue';
import { useClientEvent } from '@/composables/useClientEvent';
Expand Down Expand Up @@ -125,7 +125,7 @@ const confirm = async () => {
const sendForEnrichment = async () => {
// Build enrichment job ids list (profile asset, align model, etc...)
if (props.assetId && props.assetType === AssetType.Model) {
await modelCard(props.assetId, selectedResourceId.value);
await enrichModelMetadata(props.assetId, selectedResourceId.value, true);
} else if (props.assetType === AssetType.Dataset) {
await profileDataset(props.assetId, selectedResourceId.value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ import Textarea from 'primevue/textarea';
import TeraInputText from '@/components/widgets/tera-input-text.vue';
import TeraSaveAssetModal from '@/components/project/tera-save-asset-modal.vue';
import TeraModelDescription from '@/components/model/petrinet/tera-model-description.vue';
import { modelCard } from '@/services/goLLM';
import { enrichModelMetadata } from '@/services/goLLM';
import TeraSliderPanel from '@/components/widgets/tera-slider-panel.vue';

import TeraPdfEmbed from '@/components/widgets/tera-pdf-embed.vue';
Expand Down Expand Up @@ -388,7 +388,7 @@ function getEquationErrorLabel(equation) {
// generates the model card and fetches the model when finished
async function generateCard(modelId: string, docId: string) {
isGeneratingCard.value = true;
await modelCard(modelId, docId);
await enrichModelMetadata(modelId, docId, true);
isGeneratingCard.value = false;
await fetchModel();
}
Expand Down
27 changes: 27 additions & 0 deletions packages/client/hmi-client/src/services/goLLM.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,33 @@ export async function interventionPolicyFromDocument(
return data;
}

export async function enrichModelMetadata(modelId: string, documentId: string, overwrite: boolean): Promise<void> {
try {
const response = await API.get<TaskResponse>('/gollm/enrich-model-metadata', {
params: {
'model-id': modelId,
'document-id': documentId,
overwrite
}
});

const taskId = response.data.id;
await handleTaskById(taskId, {
ondata(data, closeConnection) {
if (data?.status === TaskStatus.Failed) {
closeConnection();
throw new FatalError('Task failed');
}
if (data.status === TaskStatus.Success) {
closeConnection();
}
}
});
Comment on lines +64 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove this code. Since, instead of creating connection for each task, we can subscribe/unsubscribe on global sse event channel whenever needed using subscribe in ClientEventService.ts or useClientEvent.ts

Copy link
Contributor

@jryu01 jryu01 Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And I believe the endpoint for subscribing for each individual task event got removed in the server side at some point so this code won't work if I remember correctly.

} catch (err) {
logger.error(err);
}
}

export async function configureModelFromDocument(
documentId: string,
modelId: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.document.ExtractedDocumentPage;
import software.uncharted.terarium.hmiserver.models.dataservice.model.Model;
import software.uncharted.terarium.hmiserver.models.task.CompoundTask;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest.TaskType;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
Expand Down Expand Up @@ -144,55 +145,30 @@ public ResponseEntity<TaskResponse> createModelCardTask(
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found"));
}

final ModelCardResponseHandler.Input input = new ModelCardResponseHandler.Input();
input.setAmr(model.get().serializeWithoutTerariumFields(null, new String[] { "gollmCard" }));

// Grab the document
final DocumentAsset document;
if (documentId != null) {
final Optional<DocumentAsset> documentOpt = documentAssetService.getAsset(documentId, permission);
if (documentOpt.isEmpty()) {
log.warn(String.format("Document %s not found", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found"));
}

document = documentOpt.get();

// make sure there is text in the document
if (document.getText() == null || document.getText().isEmpty()) {
log.warn(String.format("Document %s has no text to send", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done"));
}

// check for input length
if (document.getText().length() > ModelCardResponseHandler.MAX_TEXT_SIZE) {
log.warn(String.format("Document %s text too long for GoLLM model card task", documentId));
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("document.text-length-exceeded"));
}

input.setResearchPaper(document.getText());
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found"));
}

// Create the task
final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
req.setScript(ModelCardResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());
final Optional<DocumentAsset> documentOpt = documentAssetService.getAsset(documentId, permission);
if (documentOpt.isEmpty()) {
log.warn(String.format("Document %s not found", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found"));
}

try {
req.setInput(objectMapper.writeValueAsBytes(input));
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
// make sure there is text in the document
if (documentOpt.get().getText() == null || documentOpt.get().getText().isEmpty()) {
log.warn(String.format("Document %s has no text to send", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done"));
}

req.setProjectId(projectId);
// check for input length
if (documentOpt.get().getText().length() > ModelCardResponseHandler.MAX_TEXT_SIZE) {
log.warn(String.format("Document %s text too long for GoLLM model card task", documentId));
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("document.text-length-exceeded"));
}

final ModelCardResponseHandler.Properties props = new ModelCardResponseHandler.Properties();
props.setProjectId(projectId);
props.setDocumentId(documentId);
props.setModelId(modelId);
req.setAdditionalProperties(props);
final TaskRequest req = getModelCardTask(documentOpt.get(), model.get(), projectId);

final TaskResponse resp;
try {
Expand Down Expand Up @@ -746,6 +722,101 @@ public ResponseEntity<TaskResponse> createGenerateResponseTask(
return ResponseEntity.ok().body(resp);
}

/**
* This endpoint will dispatch a few GoLLM tasks to enrich model metadata,
* including enriching the AMR and creating the model card
* @param modelId UUID of the model to enrich
* @param documentId UUID of the document to use for enrichment
* @param mode TaskMode to run the task in (is this ASYNC?)
* @param projectId UUID of the project to associate the task with for permissions
* @param overwrite boolean to determine if the model should be overwritten
* @return TaskResponse with the task ID
*/
@GetMapping("/enrich-model-metadata")
@Secured(Roles.USER)
@Operation(summary = "Dispatch multiple GoLLM tasks to enrich model metadata")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
)
),
@ApiResponse(
responseCode = "404",
description = "The provided model or document arguments are not found",
content = @Content
),
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
public ResponseEntity<TaskResponse> createEnrichModelMetadataTask(
@RequestParam(name = "model-id", required = true) final UUID modelId,
@RequestParam(name = "document-id", required = false) final UUID documentId,
@RequestParam(name = "mode", required = false, defaultValue = "ASYNC") final TaskMode mode,
@RequestParam(name = "project-id", required = false) final UUID projectId,
@RequestParam(name = "overwrite", required = false, defaultValue = "false") final boolean overwrite
) {
final Schema.Permission permission = projectService.checkPermissionCanRead(
currentUserService.get().getId(),
projectId
);

// Grab the document
final Optional<DocumentAsset> document = documentAssetService.getAsset(documentId, permission);

// make sure there is text in the document. We don't need a document but if we do have one it can't be empty
if (document.isPresent() && (document.get().getText() == null || document.get().getText().isEmpty())) {
log.warn(String.format("Document %s has no extracted text", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done"));
}

// Grab the model
final Optional<Model> model = modelService.getAsset(modelId, permission);
if (model.isEmpty()) {
log.warn(String.format("Model %s not found", modelId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found"));
}

final TaskRequest req;

if (document.isPresent()) {
dvince2 marked this conversation as resolved.
Show resolved Hide resolved
final TaskRequest enrichAmrRequest = getEnrichAMRTaskRequest(
document.orElse(null),
model.get(),
projectId,
overwrite
);
final TaskRequest modelCardRequest = getModelCardTask(document.orElse(null), model.get(), projectId);

req = new CompoundTask(enrichAmrRequest, modelCardRequest);
} else {
req = getModelCardTask(null, model.get(), projectId);
}

final TaskResponse resp;
try {
resp = taskService.runTask(mode, req);
} catch (final JsonProcessingException e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing"));
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.gollm.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
}

return ResponseEntity.ok().body(resp);
}

@GetMapping("/enrich-amr")
@Secured(Roles.USER)
@Operation(summary = "Dispatch a `GoLLM Enrich AMR` task")
Expand Down Expand Up @@ -799,41 +870,11 @@ public ResponseEntity<TaskResponse> createEnrichAMRTask(
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found"));
}

final EnrichAmrResponseHandler.Input input = new EnrichAmrResponseHandler.Input();
input.setResearchPaper(document.get().getText());
// stripping the metadata from the model before its sent since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);

try {
final String amr = objectMapper.writeValueAsString(model.get());
input.setAmr(amr);
} catch (final JsonProcessingException e) {
log.error("Unable to serialize model card", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing"));
}

// Create the task
final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
req.setScript(EnrichAmrResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());

try {
req.setInput(objectMapper.writeValueAsBytes(input));
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req.setProjectId(projectId);

final EnrichAmrResponseHandler.Properties props = new EnrichAmrResponseHandler.Properties();
props.setProjectId(projectId);
props.setDocumentId(documentId);
props.setModelId(modelId);
props.setOverwrite(overwrite);
req.setAdditionalProperties(props);
TaskRequest req = getEnrichAMRTaskRequest(document.get(), model.get(), projectId, overwrite);

final TaskResponse resp;
try {
Expand Down Expand Up @@ -1071,4 +1112,71 @@ public ResponseEntity<Void> cancelTask(@PathVariable("task-id") final UUID taskI
taskService.cancelTask(taskId);
return ResponseEntity.ok().build();
}

private TaskRequest getModelCardTask(DocumentAsset document, Model model, UUID projectId) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should getModelCardTask and getEnrichAMRTaskRequest methods be in a service file instead of this controller?

final ModelCardResponseHandler.Input input = new ModelCardResponseHandler.Input();
input.setAmr(model.serializeWithoutTerariumFields(null, new String[] { "gollmCard" }));

if (document != null) input.setResearchPaper(document.getText());

// Create the task
final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
req.setScript(ModelCardResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());

try {
req.setInput(objectMapper.writeValueAsBytes(input));
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req.setProjectId(projectId);

final ModelCardResponseHandler.Properties props = new ModelCardResponseHandler.Properties();
props.setProjectId(projectId);
if (document != null) props.setDocumentId(document.getId());
props.setModelId(model.getId());
req.setAdditionalProperties(props);

return req;
}

private TaskRequest getEnrichAMRTaskRequest(DocumentAsset document, Model model, UUID projectId, Boolean overwrite) {
final EnrichAmrResponseHandler.Input input = new EnrichAmrResponseHandler.Input();
if (document != null) input.setResearchPaper(document.getText());

try {
final String amr = objectMapper.writeValueAsString(model);
input.setAmr(amr);
} catch (final JsonProcessingException e) {
log.error("Unable to serialize model card", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing"));
}

// Create the task
final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
req.setScript(EnrichAmrResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());

try {
req.setInput(objectMapper.writeValueAsBytes(input));
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req.setProjectId(projectId);

final EnrichAmrResponseHandler.Properties props = new EnrichAmrResponseHandler.Properties();
props.setProjectId(projectId);
if (document != null) props.setDocumentId(document.getId());
props.setModelId(model.getId());
props.setOverwrite(overwrite);
req.setAdditionalProperties(props);

return req;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package software.uncharted.terarium.hmiserver.models.task;

import java.util.List;
import lombok.Data;

/**
* Represents a compound task that consists of a primary task and one or more secondary tasks.
*/
@Data
public class CompoundTask extends TaskRequest {

/**
* Constructs a CompoundTask with a primary task and optional secondary tasks.
*
* @param primaryTask the primary task
* @param secondaryTasks the secondary tasks
*/
public CompoundTask(TaskRequest primaryTask, TaskRequest... secondaryTasks) {
this.primaryTask = primaryTask;
this.secondaryTasks = List.of(secondaryTasks);
}

/** The primary task of the compound task. */
private TaskRequest primaryTask;

/** The list of secondary tasks of the compound task. */
private List<TaskRequest> secondaryTasks;
}
Loading
Loading