Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model to latex with mira task #4831

Merged
merged 13 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions packages/client/hmi-client/src/services/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,8 @@ export async function getModelEquation(model: Model): Promise<string> {
return '';
}

/* TODO - Replace the GET with the POST when the backend is ready,
* see PR https://github.com/DARPA-ASKEM/sciml-service/pull/167
*/
const response = await API.get(`/transforms/model-to-latex/${model.id}`);
// const response = await API.post(`/transforms/model-to-latex/`, model);
const latex = response?.data?.latex;
const response = await API.post(`/mira/model-to-latex`, model);
const latex = response?.data?.response;
if (!latex) return '';
return latex ?? '';
Tom-Szendrey marked this conversation as resolved.
Show resolved Hide resolved
}
Expand Down
1 change: 1 addition & 0 deletions packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,7 @@ export enum ClientEventType {
TaskFunmanValidation = "TASK_FUNMAN_VALIDATION",
TaskGollmEnrichAmr = "TASK_GOLLM_ENRICH_AMR",
TaskMiraAmrToMmt = "TASK_MIRA_AMR_TO_MMT",
TaskMiraGenerateModelLatex = "TASK_MIRA_GENERATE_MODEL_LATEX",
TaskEnrichAmr = "TASK_ENRICH_AMR",
WorkflowUpdate = "WORKFLOW_UPDATE",
WorkflowDelete = "WORKFLOW_DELETE",
Expand Down
1 change: 1 addition & 0 deletions packages/mira/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"mira_task:mdl_to_stockflow=tasks.mdl_to_stockflow:main",
"mira_task:stella_to_stockflow=tasks.stella_to_stockflow:main",
"mira_task:amr_to_mmt=tasks.amr_to_mmt:main",
"mira_task:generate_model_latex=tasks.generate_model_latex:main",
],
},
python_requires=">=3.10",
Expand Down
76 changes: 76 additions & 0 deletions packages/mira/tasks/generate_model_latex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import sys
import json
import traceback
from taskrunner import TaskRunnerInterface
import sympy
from mira.sources.amr import model_from_json

def cleanup():
    """No-op cancellation hook registered with the task runner."""
    return

def main():
    """Generate a LaTeX rendering of an AMR model's ODE system.

    Reads a JSON-serialized AMR model via the task runner, converts it to a
    MIRA model, accumulates each concept's rate-law terms into an ODE system,
    renders the system (plus any observables) as a single LaTeX ``align``
    block, and writes ``{"response": "<latex>"}`` back through the runner.
    Exits with a non-zero code on failure.
    """
    exit_code = 0

    try:
        taskrunner = TaskRunnerInterface(description="Generate latex")
        taskrunner.on_cancellation(cleanup)

        data = taskrunner.read_input_str_with_timeout()
        amr = json.loads(data)
        model = model_from_json(amr)

        # One ODE right-hand side per concept, accumulated from the
        # templates' rate laws (subject loses mass, outcome gains it).
        odeterms = {var: 0 for var in model.get_concepts_name_map().keys()}

        for template in model.templates:
            if hasattr(template, "subject"):
                odeterms[template.subject.name] -= template.rate_law.args[0]

            if hasattr(template, "outcome"):
                odeterms[template.outcome.name] += template.rate_law.args[0]

        # Time symbol: fall back to "t" when the model carries no time
        # metadata. Narrowed from a bare `except:` so real errors propagate.
        try:
            time = model.time.name
        except AttributeError:
            time = "t"
        t = sympy.Symbol(time)

        # Construct the LaTeX equation d var/dt = terms for each concept.
        odesys = [
            sympy.latex(sympy.Eq(sympy.diff(sympy.Function(var)(t), t), terms))
            for var, terms in odeterms.items()
        ]

        # Observables are appended as extra equations. BUG FIX: previously
        # `obs_eqs` was referenced unconditionally even though it was only
        # defined when observables existed, raising NameError for models
        # without observables.
        if model.observables:
            odesys += [
                f"{{{obs.name}}}(t) = " + sympy.latex(obs.expression.args[0])
                for obs in model.observables.values()
            ]

        # Wrap all equations in a single align environment.
        latex = "\\begin{align} \n " + " \\\\ \n ".join(odesys) + "\n\\end{align}"

        taskrunner.write_output_dict_with_timeout({"response": latex})
        print("Generate latex succeeded")

    except Exception as e:
        sys.stderr.write(f"Error: {str(e)}\n")
        sys.stderr.write(traceback.format_exc())
        sys.stderr.flush()
        exit_code = 1

    # NOTE(review): if TaskRunnerInterface() itself raises, `taskrunner` is
    # unbound here (pre-existing behavior, preserved) — consider constructing
    # it before the try block.
    taskrunner.log("Shutting down")
    taskrunner.shutdown()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import software.uncharted.terarium.hmiserver.service.data.ModelConfigurationService;
import software.uncharted.terarium.hmiserver.service.data.ProjectService;
import software.uncharted.terarium.hmiserver.service.tasks.AMRToMMTResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.GenerateModelLatexResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.MdlToStockflowResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SbmlToPetrinetResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.StellaToStockflowResponseHandler;
Expand Down Expand Up @@ -167,6 +168,66 @@ public ResponseEntity<JsonNode> convertAMRtoMMT(@RequestBody final JsonNode mode
return ResponseEntity.ok().body(mmtInfo);
}

// Dispatches a synchronous MIRA task that converts the posted AMR model into
// a LaTeX rendering of its ODE system and returns the task output as JSON
// (the python task writes {"response": "<latex>"}).
@PostMapping("/model-to-latex")
@Secured(Roles.USER)
@Operation(summary = "Generate latex from a model id")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
)
),
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
public ResponseEntity<JsonNode> generateModelLatex(@RequestBody final JsonNode model) {
// Build the task request; the model is serialized without Terarium-specific
// fields so the MIRA task receives plain AMR JSON.
final TaskRequest req = new TaskRequest();
req.setType(TaskType.MIRA);

try {
req.setInput(objectMapper.treeToValue(model, Model.class).serializeWithoutTerariumFields().getBytes());
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req.setScript(GenerateModelLatexResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());

// Run the task synchronously, mapping each failure mode to a distinct
// HTTP status (500 / 503 / 422 / 500).
final TaskResponse resp;
try {
resp = taskService.runTaskSync(req);
} catch (final JsonProcessingException e) {
// NOTE(review): message says "serialize input" but this is thrown while
// processing the task response — consider a more accurate message.
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.json-processing"));
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.mira.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure"));
}

// Deserialize the raw task output bytes into JSON for the client.
final JsonNode latexResponse;
try {
latexResponse = objectMapper.readValue(resp.getOutput(), JsonNode.class);
} catch (final IOException e) {
log.error("Unable to deserialize output", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
}

return ResponseEntity.ok().body(latexResponse);
}

@PostMapping("/convert-and-create-model")
@Secured(Roles.USER)
@Operation(summary = "Dispatch a MIRA conversion task")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public enum ClientEventType {
TASK_FUNMAN_VALIDATION,
TASK_GOLLM_ENRICH_AMR,
TASK_MIRA_AMR_TO_MMT,
TASK_MIRA_GENERATE_MODEL_LATEX,
TASK_ENRICH_AMR,
WORKFLOW_UPDATE,
WORKFLOW_DELETE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

/**
 * Response handler for the MIRA "generate model latex" task. The python task
 * produces the output; this handler only contributes the script name used to
 * route task responses.
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class GenerateModelLatexResponseHandler extends TaskResponseHandler {

	/** Matches the entry point registered in packages/mira/setup.py. */
	public static final String NAME = "mira_task:generate_model_latex";

	@Override
	public String getName() {
		return GenerateModelLatexResponseHandler.NAME;
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ public class TaskNotificationEventTypes {
EnrichAmrResponseHandler.NAME,
ClientEventType.TASK_GOLLM_ENRICH_AMR,
AMRToMMTResponseHandler.NAME,
ClientEventType.TASK_MIRA_AMR_TO_MMT
ClientEventType.TASK_MIRA_AMR_TO_MMT,
GenerateModelLatexResponseHandler.NAME,
ClientEventType.TASK_MIRA_GENERATE_MODEL_LATEX
);

public static ClientEventType getTypeFor(final String taskName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,26 @@ public void testItCanSendAmrToMmtRequest() throws Exception {
log.info(new String(resp.getOutput()));
}

// @Test
@WithUserDetails(MockUser.URSULA)
public void testItCanSendGenerateModelLatexRequest() throws Exception {
	// Load a sample AMR model fixture to convert to LaTeX.
	final ClassPathResource resource = new ClassPathResource("mira/problem.json");
	final String content = new String(Files.readAllBytes(resource.getFile().toPath()));

	final TaskRequest req = new TaskRequest();
	req.setType(TaskType.MIRA);
	req.setScript("mira_task:generate_model_latex");
	req.setInput(content.getBytes());

	final TaskResponse resp = taskService.runTaskSync(req);

	// BUG FIX: the original asserted resp.getId() equals a locally generated
	// random UUID that was never attached to the request, so the assertion
	// could not pass. Assert the task completed with an id and output instead.
	Assertions.assertNotNull(resp.getId());
	Assertions.assertNotNull(resp.getOutput());

	log.info(new String(resp.getOutput()));
}

// @Test
@WithUserDetails(MockUser.URSULA)
public void testItCanCacheSuccess() throws Exception {
Expand Down
Loading