Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

if we can't find a grounding, we set as a default empty grounding #5367

Merged
merged 5 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/client/hmi-client/src/services/goLLM.ts
YohannParis marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export async function interventionPolicyFromDocument(

export async function enrichModelMetadata(modelId: string, documentId: string, overwrite: boolean): Promise<void> {
try {
await API.get<TaskResponse>('/gollm/enrich-model-metadata', {
await API.get('/gollm/enrich-model-metadata', {
params: {
'model-id': modelId,
'document-id': documentId,
Expand Down
1 change: 1 addition & 0 deletions packages/gollm/gollm_openai/prompts/latex_style_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
11) Do not use "\\cdot" or "*" to indicate multiplication. Use whitespace instead.
12) Replace "\\epsilon" with "\\varepsilon" when representing a parameter or variable
13) Avoid using notation for mathematical constants like "e" and "pi". Use their actual values up to 3 decimal places instead.
14) If equations are separated by commas, do not include commas in the LaTeX code.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@
import software.uncharted.terarium.hmiserver.models.dataservice.dataset.Dataset;
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.model.Model;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.ModelParameter;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.semantics.Observable;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.semantics.State;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.semantics.Transition;
import software.uncharted.terarium.hmiserver.models.dataservice.regnet.RegNetVertex;
import software.uncharted.terarium.hmiserver.models.task.CompoundTask;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.ModelMetadata;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest.TaskType;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
Expand Down Expand Up @@ -269,7 +264,7 @@ public ResponseEntity<TaskResponse> createConfigureModelFromDocumentTask(

// stripping the metadata from the model before it is sent, since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
model.get().setMetadata(new ModelMetadata());
input.setAmr(model.get().serializeWithoutTerariumFields(new String[] { "id" }, null));

// Create the task
Expand Down Expand Up @@ -397,7 +392,7 @@ public ResponseEntity<TaskResponse> createConfigureModelFromDatasetTask(
input.setDataset(dataArray);
// stripping the metadata from the model before it is sent, since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
model.get().setMetadata(new ModelMetadata());
input.setAmr(model.get().serializeWithoutTerariumFields(null, null));

// set matrix string if provided
Expand Down Expand Up @@ -512,7 +507,7 @@ public ResponseEntity<TaskResponse> createInterventionsFromDocumentTask(

// stripping the metadata from the model before it is sent, since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
model.get().setMetadata(new ModelMetadata());
input.setAmr(model.get().serializeWithoutTerariumFields(new String[] { "id" }, null));

// Create the task
Expand Down Expand Up @@ -748,7 +743,7 @@ public ResponseEntity<TaskResponse> createGenerateResponseTask(
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = UUID.class)
)
),
@ApiResponse(
Expand All @@ -759,10 +754,10 @@ public ResponseEntity<TaskResponse> createGenerateResponseTask(
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
public ResponseEntity<TaskResponse> createEnrichModelMetadataTask(
public ResponseEntity<UUID> createEnrichModelMetadataTask(
@RequestParam(name = "model-id", required = true) final UUID modelId,
@RequestParam(name = "document-id", required = false) final UUID documentId,
@RequestParam(name = "mode", required = false, defaultValue = "ASYNC") final TaskMode mode,
@RequestParam(name = "mode", required = false, defaultValue = "SYNC") final TaskMode mode,
@RequestParam(name = "project-id", required = false) final UUID projectId,
@RequestParam(name = "overwrite", required = false, defaultValue = "false") final boolean overwrite
) {
Expand All @@ -771,129 +766,30 @@ public ResponseEntity<TaskResponse> createEnrichModelMetadataTask(
projectId
);

// Grab the document
final Optional<DocumentAsset> document = documentAssetService.getAsset(documentId, permission);

// make sure there is text in the document. We don't need a document but if we
// do have one it can't be empty
if (document.isPresent() && (document.get().getText() == null || document.get().getText().isEmpty())) {
log.warn(String.format("Document %s has no extracted text", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done"));
}

// Grab the model
Optional<Model> modelOptional = modelService.getAsset(modelId, permission);
if (modelOptional.isEmpty()) {
log.warn(String.format("Model %s not found", modelId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found"));
}

final TaskRequest req;

if (document.isPresent()) {
final TaskRequest enrichAmrRequest;
try {
enrichAmrRequest = TaskUtilities.getEnrichAMRTaskRequest(
currentUserService.get().getId(),
document.orElse(null),
modelOptional.get(),
projectId,
overwrite
);
} catch (final IOException e) {
log.error("Unable to create Enrich AMR task", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

final TaskRequest modelCardRequest;
try {
modelCardRequest = TaskUtilities.getModelCardTask(
currentUserService.get().getId(),
document.orElse(null),
modelOptional.get(),
projectId
);
} catch (final IOException e) {
log.error("Unable to create Model Card task", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req = new CompoundTask(enrichAmrRequest, modelCardRequest);
} else {
try {
req = TaskUtilities.getModelCardTask(currentUserService.get().getId(), null, modelOptional.get(), projectId);
} catch (final IOException e) {
log.error("Unable to create Model Card task", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}
}

final TaskResponse resp;
try {
resp = taskService.runTask(mode, req);
} catch (final JsonProcessingException e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing"));
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.gollm.interrupted"));
} catch (final ExecutionException e) {
modelService.enrichModel(projectId, documentId, modelId, permission, true);
} catch (final IOException e) {
log.error("An error occurred while trying to retrieve information necessary for model enrichment.", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("postgres.service-unavailable"));
} catch (ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
} catch (InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.gollm.interrupted"));
} catch (TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
}

// at this point the initial enrichment has happened.
modelOptional = modelService.getAsset(modelId, permission);
if (modelOptional.isEmpty()) {
// this would be a very strange case
log.warn(String.format("Model %s not found", modelId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found"));
}

final Model model = modelOptional.get();

// Update State Grounding
if (!model.isRegnet()) {
final List<State> states = model.getStates();
states.forEach(state -> TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, state));
model.setStates(states);
} else if (model.isRegnet()) {
final List<RegNetVertex> vertices = model.getVerticies();
vertices.forEach(vertex -> TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, vertex));
model.setVerticies(vertices);
}

// Update Observable Grounding
if (model.getObservables() != null && !model.getObservables().isEmpty()) {
final List<Observable> observables = model.getObservables();
observables.forEach(observable -> TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, observable));
model.setObservables(observables);
}

// Update Parameter Grounding
if (model.getParameters() != null && !model.getParameters().isEmpty()) {
final List<ModelParameter> parameters = model.getParameters();
parameters.forEach(parameter -> TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, parameter));
model.setParameters(parameters);
}

// Update Transition Grounding
if (model.getTransitions() != null && !model.getTransitions().isEmpty()) {
final List<Transition> transitions = model.getTransitions();
transitions.forEach(transition -> TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, transition));
model.setTransitions(transitions);
}

try {
modelService.updateAsset(model, projectId, permission);
} catch (final IOException e) {
throw new RuntimeException(e);
// check that we have a model to return
final Optional<Model> model = modelService.getAsset(modelId, permission);
if (model.isEmpty()) {
log.error(String.format("The model %s does not exist.", modelId));
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("model.not-found"));
}

return ResponseEntity.ok().body(resp);
return ResponseEntity.ok().body(model.get().getId());
}

@GetMapping("/enrich-amr")
Expand Down
Loading
Loading