Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use gollm=>mira equation to amr pipeline #5851

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/client/hmi-client/src/temp/Equations.vue
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ const latex2amr = async () => {
resultAmr.value.innerHTML = 'processing...';
}

const resp = await API.post('/mira/latex-to-amr', equations);
const resp = await API.post('/knowledge/equations-to-model-debug', equations);
const respData = resp.data;

if (resultCode.value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
import software.uncharted.terarium.hmiserver.service.notification.NotificationService;
import software.uncharted.terarium.hmiserver.service.tasks.EquationsCleanupResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.LatexToAMRResponseHandler;
// latex to model chain
import software.uncharted.terarium.hmiserver.service.tasks.LatexToSympyResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SympyToAMRResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService.TaskMode;
import software.uncharted.terarium.hmiserver.utils.ByteMultipartFile;
Expand Down Expand Up @@ -267,19 +270,35 @@ public ResponseEntity<UUID> equationsToModel(
}

if (extractionService.equals("mira")) {
final TaskRequest taskReq = new TaskRequest();
final String latex = req.get("equations").toString();
taskReq.setType(TaskType.MIRA);
try {
taskReq.setInput(latex.getBytes());
taskReq.setScript(LatexToAMRResponseHandler.NAME);
taskReq.setUserId(currentUserService.get().getId());
final TaskResponse taskResp = taskService.runTaskSync(taskReq);
final JsonNode taskResponseJSON = mapper.readValue(taskResp.getOutput(), JsonNode.class);
final TaskRequest latexToSympyRequest = new TaskRequest();
final TaskResponse latexToSympyResponse;
final TaskRequest sympyToAMRRequest = new TaskRequest();
final TaskResponse sympyToAMRResponse;

try {
// 1. LaTeX to sympy code
final String latex = req.get("equations").toString();
latexToSympyRequest.setType(TaskType.GOLLM);
latexToSympyRequest.setInput(latex.getBytes());
latexToSympyRequest.setScript(LatexToSympyResponseHandler.NAME);
latexToSympyRequest.setUserId(currentUserService.get().getId());
latexToSympyResponse = taskService.runTaskSync(latexToSympyRequest);

// 2. hand off
final JsonNode node = mapper.readValue(latexToSympyResponse.getOutput(), JsonNode.class);
final String code = node.get("response").asText();

// 3. sympy code string to amr json
sympyToAMRRequest.setType(TaskType.MIRA);
sympyToAMRRequest.setInput(code.getBytes());
sympyToAMRRequest.setScript(SympyToAMRResponseHandler.NAME);
sympyToAMRRequest.setUserId(currentUserService.get().getId());
sympyToAMRResponse = taskService.runTaskSync(sympyToAMRRequest);

final JsonNode taskResponseJSON = mapper.readValue(sympyToAMRResponse.getOutput(), JsonNode.class);
final ObjectNode amrNode = taskResponseJSON.get("response").get("amr").deepCopy();
responseAMR = mapper.convertValue(amrNode, Model.class);
} catch (Exception e) {
} catch (final Exception e) {
log.error("failed to convert LaTeX equations to AMR", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "failed to convert latex equations to AMR");
}
Expand Down Expand Up @@ -868,6 +887,95 @@ public ResponseEntity<Void> pdfExtractions(
return ResponseEntity.accepted().build();
}

@PostMapping("/equations-to-model-debug")
@Secured(Roles.USER)
@Operation(summary = "Generate AMR from latex ODE equations, DEBUGGING only")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
)
),
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
/**
* This is similiar to /equations-to-model endpoint, but rather than
* directly creating a model asset it returns artifact at different intersections
* and handoff for debugging
**/
public ResponseEntity<JsonNode> latexToAMR(@RequestBody final String latex) {
////////////////////////////////////////////////////////////////////////////////
// 1. Convert latex string to python sympy code string
//
// Note this is a gollm string => string task
////////////////////////////////////////////////////////////////////////////////
final TaskRequest latexToSympyRequest = new TaskRequest();
final TaskResponse latexToSympyResponse;
String code = null;

try {
latexToSympyRequest.setType(TaskType.GOLLM);
latexToSympyRequest.setInput(latex.getBytes());
latexToSympyRequest.setScript(LatexToSympyResponseHandler.NAME);
latexToSympyRequest.setUserId(currentUserService.get().getId());
latexToSympyResponse = taskService.runTaskSync(latexToSympyRequest);

final JsonNode node = mapper.readValue(latexToSympyResponse.getOutput(), JsonNode.class);
code = node.get("response").asText();
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
} catch (final Exception e) {
log.error("Unexpected error", e);
}

if (code == null) {
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
}

////////////////////////////////////////////////////////////////////////////////
// 2. Convert python sympy code string to amr
//
// This returns the AMR json, and intermediate data representations for debugging
////////////////////////////////////////////////////////////////////////////////
final TaskRequest sympyToAMRRequest = new TaskRequest();
final TaskResponse sympyToAMRResponse;
final JsonNode response;

try {
sympyToAMRRequest.setType(TaskType.MIRA);
sympyToAMRRequest.setInput(code.getBytes());
sympyToAMRRequest.setScript(SympyToAMRResponseHandler.NAME);
sympyToAMRRequest.setUserId(currentUserService.get().getId());
sympyToAMRResponse = taskService.runTaskSync(sympyToAMRRequest);
response = mapper.readValue(sympyToAMRResponse.getOutput(), JsonNode.class);
return ResponseEntity.ok().body(response);
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure"));
} catch (final Exception e) {
log.error("Unexpected error", e);
}
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
}

private ResponseStatusException handleSkemaFeignException(final FeignException e) {
final HttpStatus statusCode = HttpStatus.resolve(e.status());
if (statusCode != null && statusCode.is4xxClientError()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,9 @@
import software.uncharted.terarium.hmiserver.service.tasks.AMRToMMTResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.GenerateModelLatexResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.LatexToAMRResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.LatexToSympyResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.MdlToStockflowResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SbmlToPetrinetResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.StellaToStockflowResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SympyToAMRResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.utils.Messages;
import software.uncharted.terarium.hmiserver.utils.rebac.Schema;
Expand Down Expand Up @@ -254,90 +252,6 @@ public ResponseEntity<JsonNode> generateModelLatex(@RequestBody final JsonNode m
return ResponseEntity.ok().body(latexResponse);
}

@PostMapping("/latex-to-amr")
@Secured(Roles.USER)
@Operation(summary = "Generate AMR from latex ODE equations")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
)
),
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
public ResponseEntity<JsonNode> latexToAMR(@RequestBody final String latex) {
////////////////////////////////////////////////////////////////////////////////
// 1. Convert latex string to python sympy code string
//
// Note this is a gollm string => string task
////////////////////////////////////////////////////////////////////////////////
final TaskRequest latexToSympyRequest = new TaskRequest();
final TaskResponse latexToSympyResponse;
String code = null;

try {
latexToSympyRequest.setType(TaskType.GOLLM);
latexToSympyRequest.setInput(latex.getBytes());
latexToSympyRequest.setScript(LatexToSympyResponseHandler.NAME);
latexToSympyRequest.setUserId(currentUserService.get().getId());
latexToSympyResponse = taskService.runTaskSync(latexToSympyRequest);

final JsonNode node = objectMapper.readValue(latexToSympyResponse.getOutput(), JsonNode.class);
code = node.get("response").asText();
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
} catch (final Exception e) {
log.error("Unexpected error", e);
}

if (code == null) {
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
}

////////////////////////////////////////////////////////////////////////////////
// 2. Convert python sympy code string to amr
//
// This returns the AMR json, and intermediate data representations for debugging
////////////////////////////////////////////////////////////////////////////////
final TaskRequest sympyToAMRRequest = new TaskRequest();
final TaskResponse sympyToAMRResponse;
final JsonNode response;

try {
sympyToAMRRequest.setType(TaskType.MIRA);
sympyToAMRRequest.setInput(code.getBytes());
sympyToAMRRequest.setScript(SympyToAMRResponseHandler.NAME);
sympyToAMRRequest.setUserId(currentUserService.get().getId());
sympyToAMRResponse = taskService.runTaskSync(sympyToAMRRequest);
response = objectMapper.readValue(sympyToAMRResponse.getOutput(), JsonNode.class);
return ResponseEntity.ok().body(response);
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure"));
} catch (final Exception e) {
log.error("Unexpected error", e);
}
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
}

@PostMapping("/convert-and-create-model")
@Secured(Roles.USER)
@Operation(summary = "Dispatch a MIRA conversion task")
Expand Down