diff --git a/packages/client/hmi-client/src/types/Types.ts b/packages/client/hmi-client/src/types/Types.ts index cbb4515fdc..9a568ae891 100644 --- a/packages/client/hmi-client/src/types/Types.ts +++ b/packages/client/hmi-client/src/types/Types.ts @@ -1157,6 +1157,7 @@ export enum ProvenanceType { Document = "Document", Workflow = "Workflow", Equation = "Equation", + InterventionPolicy = "InterventionPolicy", } export enum SimulationType { diff --git a/packages/gollm/tasks/interventions_from_document.py b/packages/gollm/tasks/interventions_from_document.py new file mode 100644 index 0000000000..7c847250fd --- /dev/null +++ b/packages/gollm/tasks/interventions_from_document.py @@ -0,0 +1,44 @@ +import json +import sys +from gollm.entities import InterventionsFromDocument +from gollm.openai.tool_utils import interventions_from_document + +from taskrunner import TaskRunnerInterface + + +def cleanup(): + pass + + +def main(): + exitCode = 0 + try: + taskrunner = TaskRunnerInterface(description="Extract interventions from paper CLI") + taskrunner.on_cancellation(cleanup) + + input_dict = taskrunner.read_input_dict_with_timeout() + + taskrunner.log("Creating InterventionsFromDocument model from input") + input_model = InterventionsFromDocument(**input_dict) + amr = json.dumps(input_model.amr, separators=(",", ":")) + + taskrunner.log("Sending request to OpenAI API") + response = interventions_from_document( + research_paper=input_model.research_paper, amr=amr + ) + taskrunner.log("Received response from OpenAI API") + + taskrunner.write_output_dict_with_timeout({"response": response}) + + except Exception as e: + sys.stderr.write(f"Error: {str(e)}\n") + sys.stderr.flush() + exitCode = 1 + + taskrunner.log("Shutting down") + taskrunner.shutdown() + sys.exit(exitCode) + + +if __name__ == "__main__": + main() diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/dataservice/SimulationController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/dataservice/SimulationController.java index e2c62f789c..28c4da0f74 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/dataservice/SimulationController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/dataservice/SimulationController.java @@ -6,7 +6,6 @@ import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponses; -import jakarta.transaction.Transactional; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -370,7 +369,7 @@ public ResponseEntity createFromSimulationResult( datasetService.updateAsset(dataset, projectId, permission); // If this is a temporary asset, do not add to project. - if (addToProject == false) { + if (!addToProject) { return ResponseEntity.status(HttpStatus.CREATED).body(dataset); } diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java index 8cee8d6b43..ec1fdda042 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/gollm/GoLLMController.java @@ -57,6 +57,7 @@ import software.uncharted.terarium.hmiserver.service.tasks.EquationsFromImageResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.GenerateResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.GenerateSummaryHandler; +import software.uncharted.terarium.hmiserver.service.tasks.InterventionsFromDocumentResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.ModelCardResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.TaskService; import software.uncharted.terarium.hmiserver.service.tasks.TaskService.TaskMode; @@ -448,6 +449,112 @@ public ResponseEntity createConfigureModelFromDatasetTask( return ResponseEntity.ok().body(resp); } + @PostMapping("/interventions-from-document") + @Secured(Roles.USER) + @Operation(summary = "Dispatch a `GoLLM interventions-from-document` task") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "Dispatched successfully", + content = @Content( + mediaType = "application/json", + schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class) + ) + ), + @ApiResponse( + responseCode = "404", + description = "The provided model or document arguments are not found", + content = @Content + ), + @ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content) + } + ) + public ResponseEntity createInterventionsFromDocumentTask( + @RequestParam(name = "model-id", required = true) final UUID modelId, + @RequestParam(name = "document-id", required = true) final UUID documentId, + @RequestParam(name = "mode", required = false, defaultValue = "ASYNC") final TaskMode mode, + @RequestParam(name = "workflow-id", required = false) final UUID workflowId, + @RequestParam(name = "node-id", required = false) final UUID nodeId, + @RequestParam(name = "project-id", required = false) final UUID projectId + ) { + final Schema.Permission permission = projectService.checkPermissionCanRead( + currentUserService.get().getId(), + projectId + ); + + // Grab the document + final Optional document = documentAssetService.getAsset(documentId, permission); + if (document.isEmpty()) { + log.warn(String.format("Document %s not found", documentId)); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found")); + } + + // make sure there is text in the document + if (document.get().getText() == null || document.get().getText().isEmpty()) { + log.warn(String.format("Document %s has no extracted text", documentId)); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done")); + } + + // Grab the model + final Optional model = modelService.getAsset(modelId, permission); + if (model.isEmpty()) { + log.warn(String.format("Model %s not found", modelId)); + throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found")); + } + + final InterventionsFromDocumentResponseHandler.Input input = new InterventionsFromDocumentResponseHandler.Input(); + input.setResearchPaper(document.get().getText()); + + // stripping the metadata from the model before its sent since it can cause + // gollm to fail with massive inputs + model.get().setMetadata(null); + input.setAmr(model.get().serializeWithoutTerariumFieldsKeepId()); + + // Create the task + final TaskRequest req = new TaskRequest(); + req.setType(TaskType.GOLLM); + req.setScript(InterventionsFromDocumentResponseHandler.NAME); + req.setUserId(currentUserService.get().getId()); + + try { + req.setInput(objectMapper.writeValueAsBytes(input)); + } catch (final Exception e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); + } + + req.setProjectId(projectId); + + final InterventionsFromDocumentResponseHandler.Properties props = + new InterventionsFromDocumentResponseHandler.Properties(); + props.setProjectId(projectId); + props.setDocumentId(documentId); + props.setModelId(modelId); + props.setWorkflowId(workflowId); + props.setNodeId(nodeId); + req.setAdditionalProperties(props); + + final TaskResponse resp; + try { + resp = taskService.runTask(mode, req); + } catch (final JsonProcessingException e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing")); + } catch (final TimeoutException e) { + log.warn("Timeout while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout")); + } catch (final InterruptedException e) { + log.warn("Interrupted while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.gollm.interrupted")); + } catch (final ExecutionException e) { + log.error("Error while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure")); + } + + return ResponseEntity.ok().body(resp); + } + @GetMapping("/compare-models") @Secured(Roles.USER) @Operation(summary = "Dispatch a `GoLLM Compare Models` task") diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/Identifier.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/Identifier.java index e45bb4c238..be0de37cb3 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/Identifier.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/Identifier.java @@ -1,9 +1,11 @@ package software.uncharted.terarium.hmiserver.models.dataservice; +import java.io.Serial; import java.io.Serializable; import software.uncharted.terarium.hmiserver.annotations.TSModel; @TSModel public record Identifier(String curie, String name) implements Serializable { + @Serial private static final long serialVersionUID = 302308407252037615L; } diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/provenance/ProvenanceType.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/provenance/ProvenanceType.java index f38acdf9c6..e63655e72b 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/provenance/ProvenanceType.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/provenance/ProvenanceType.java @@ -45,7 +45,10 @@ public enum ProvenanceType { WORKFLOW("Workflow"), @JsonAlias("Equation") - EQUATION("Equation"); + EQUATION("Equation"), + + @JsonAlias("InterventionPolicy") + INTERVENTION_POLICY("InterventionPolicy"); public final String type; diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsFromImageResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsFromImageResponseHandler.java index 012d7d19a7..8491e1f288 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsFromImageResponseHandler.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsFromImageResponseHandler.java @@ -98,9 +98,9 @@ public TaskResponse onSuccess(final TaskResponse resp) { provenanceService.createProvenance( new Provenance() .setLeft(props.getDocumentId()) - .setLeftType(ProvenanceType.DOCUMENT) + .setLeftType(ProvenanceType.EQUATION) .setRight(props.getDocumentId()) - .setRightType(ProvenanceType.EQUATION) + .setRightType(ProvenanceType.DOCUMENT) .setRelationType(ProvenanceRelationType.EXTRACTED_FROM) ); } catch (final Exception e) { diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/InterventionsFromDocumentResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/InterventionsFromDocumentResponseHandler.java new file mode 100644 index 0000000000..d24a440fb3 --- /dev/null +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/InterventionsFromDocumentResponseHandler.java @@ -0,0 +1,100 @@ +package software.uncharted.terarium.hmiserver.service.tasks; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.UUID; +import lombok.Data; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import software.uncharted.terarium.hmiserver.models.dataservice.provenance.Provenance; +import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceRelationType; +import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceType; +import software.uncharted.terarium.hmiserver.models.simulationservice.interventions.InterventionPolicy; +import software.uncharted.terarium.hmiserver.models.task.TaskResponse; +import software.uncharted.terarium.hmiserver.service.data.DocumentAssetService; +import software.uncharted.terarium.hmiserver.service.data.InterventionService; +import software.uncharted.terarium.hmiserver.service.data.ProvenanceService; + +@Component +@RequiredArgsConstructor +@Slf4j +public class InterventionsFromDocumentResponseHandler extends TaskResponseHandler { + + public static final String NAME = "gollm_task:interventions_from_document"; + + private final ObjectMapper objectMapper; + private final InterventionService interventionService; + private final ProvenanceService provenanceService; + private final DocumentAssetService documentAssetService; + + @Override + public String getName() { + return NAME; + } + + @Data + public static class Input { + + @JsonProperty("research_paper") + String researchPaper; + + @JsonProperty("amr") + String amr; + } + + @Data + public static class Response { + + JsonNode response; + } + + @Data + public static class Properties { + + UUID projectId; + UUID documentId; + UUID modelId; + UUID workflowId; + UUID nodeId; + } + + @Override + public TaskResponse onSuccess(final TaskResponse resp) { + try { + final Properties props = resp.getAdditionalProperties(Properties.class); + final Response interventionPolicies = objectMapper.readValue(resp.getOutput(), Response.class); + + // For each configuration, create a new model configuration + for (final JsonNode policy : interventionPolicies.response.get("interventionPolicies")) { + final InterventionPolicy ip = objectMapper.treeToValue(policy, InterventionPolicy.class); + + if (ip.getModelId() != props.modelId) { + ip.setModelId(props.modelId); + } + + final InterventionPolicy newPolicy = interventionService.createAsset( + ip, + props.projectId, + ASSUME_WRITE_PERMISSION_ON_BEHALF_OF_USER + ); + + // add provenance + provenanceService.createProvenance( + new Provenance() + .setLeft(newPolicy.getId()) + .setLeftType(ProvenanceType.INTERVENTION_POLICY) + .setRight(props.documentId) + .setRightType(ProvenanceType.DOCUMENT) + .setRelationType(ProvenanceRelationType.EXTRACTED_FROM) + ); + } + } catch (final Exception e) { + log.error("Failed to extract intervention policy", e); + throw new RuntimeException(e); + } + log.info("Intervention policy extracted successfully"); + return resp; + } +} diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/utils/rebac/ReBACService.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/utils/rebac/ReBACService.java index f212699058..d89422b7b6 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/utils/rebac/ReBACService.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/utils/rebac/ReBACService.java @@ -90,24 +90,23 @@ private String getKeycloakBearerToken() { return "Bearer " + keycloak.tokenManager().getAccessTokenString(); } - private class CacheKey { + private static class CacheKey { SchemaObject who; Schema.Permission permission; SchemaObject what; - CacheKey(SchemaObject who, Schema.Permission permission, SchemaObject what) { + CacheKey(final SchemaObject who, final Schema.Permission permission, final SchemaObject what) { this.who = who; this.permission = permission; this.what = what; } @Override - public boolean equals(Object o) { - if (!(o instanceof CacheKey)) { + public boolean equals(final Object o) { + if (!(o instanceof final CacheKey other)) { return false; } - CacheKey other = (CacheKey) o; return who.equals(other.who) && permission == other.permission && what.equals(other.what); } @@ -121,13 +120,13 @@ public int hashCode() { .expireAfterWrite(5, TimeUnit.MINUTES) .recordStats() .removalListener((Object key, Object value, RemovalCause cause) -> log.trace("Key {} was removed {}", key, cause)) - .build(); + .build(); private final Cache userCache = Caffeine.newBuilder() .expireAfterWrite(15, TimeUnit.MINUTES) .recordStats() .removalListener((Object key, Object value, RemovalCause cause) -> log.trace("Key {} was removed {}", key, cause)) - .build(); + .build(); @PostConstruct void startup() throws Exception { @@ -281,7 +280,7 @@ public PermissionGroup createGroup(final String name) { @Observed(name = "function_profile") public PermissionUser getUser(final String id) { @PolyNull - PermissionUser result = userCache.get(id, key_id -> { + final PermissionUser result = userCache.get(id, key_id -> { final UsersResource usersResource = keycloak.realm(REALM_NAME).users(); final UserResource userResource = usersResource.get(key_id); final UserRepresentation userRepresentation = userResource.toRepresentation(); @@ -322,7 +321,7 @@ public List getUsers() { } @PolyNull - PermissionUser user = userCache.get(userRepresentation.getId(), key_id -> { + final PermissionUser user = userCache.get(userRepresentation.getId(), key_id -> { final UserResource userResource = usersResource.get(key_id); final List roles = new ArrayList<>(); @@ -419,12 +418,12 @@ public PermissionGroup getGroup(final String id) { @Observed(name = "function_profile") public boolean can(final SchemaObject who, final Schema.Permission permission, final SchemaObject what) { @PolyNull - Boolean result = permissionCache.get(new CacheKey(who, permission, what), permissionMappingFn); + final Boolean result = permissionCache.get(new CacheKey(who, permission, what), permissionMappingFn); log.trace("Cache hit: {}, miss: {}", permissionCache.stats().hitCount(), permissionCache.stats().missCount()); return result; } - private Function permissionMappingFn = key -> { + private final Function permissionMappingFn = key -> { final ReBACFunctions rebac = new ReBACFunctions(channel, spiceDbBearerToken); try { if (SPICEDB_LAUNCHMODE.equals("TEST")) { @@ -440,7 +439,10 @@ public boolean can(final SchemaObject who, final Schema.Permission permission, f @Observed(name = "function_profile") public boolean isMemberOf(final SchemaObject who, final SchemaObject what) throws Exception { @PolyNull - Boolean result = permissionCache.get(new CacheKey(who, Schema.Permission.MEMBERSHIP, what), permissionMappingFn); + final Boolean result = permissionCache.get( + new CacheKey(who, Schema.Permission.MEMBERSHIP, what), + permissionMappingFn + ); log.trace("Cache hit: {}, miss: {}", permissionCache.stats().hitCount(), permissionCache.stats().missCount()); return result; } @@ -451,7 +453,7 @@ public boolean isCreator(final SchemaObject who, final SchemaObject what) throws return rebac.hasRelationship(who, Schema.Relationship.CREATOR, what, getCurrentConsistency()); } - private void invalidatePermissionCache(SchemaObject who, SchemaObject what) { + private void invalidatePermissionCache(final SchemaObject who, final SchemaObject what) { permissionCache.invalidate(new CacheKey(who, Schema.Permission.READ, what)); permissionCache.invalidate(new CacheKey(who, Schema.Permission.WRITE, what)); permissionCache.invalidate(new CacheKey(who, Schema.Permission.MEMBERSHIP, what)); diff --git a/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java b/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java index cab1021ecf..b83caa8e3e 100644 --- a/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java +++ b/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java @@ -312,6 +312,33 @@ public void testItCanSendGoLLMConfigFromDatasetRequest() throws Exception { log.info(new String(resp.getOutput())); } + // @Test + @WithUserDetails(MockUser.URSULA) + public void testItCanSendGoLLMInterventionsFromDocumentRequest() throws Exception { + final UUID taskId = UUID.randomUUID(); + + final ClassPathResource modelResource = new ClassPathResource("gollm/SIR.json"); + final String modelContent = new String(Files.readAllBytes(modelResource.getFile().toPath())); + + final ClassPathResource documentResource = new ClassPathResource("gollm/SIR.txt"); + final String documentContent = new String(Files.readAllBytes(documentResource.getFile().toPath())); + + final InterventionsFromDocumentResponseHandler.Input input = new InterventionsFromDocumentResponseHandler.Input(); + input.setResearchPaper(documentContent); + input.setAmr(modelContent); + + final TaskRequest req = new TaskRequest(); + req.setType(TaskType.GOLLM); + req.setScript("gollm_task:interventions_from_document"); + req.setInput(input); + + final TaskResponse resp = taskService.runTaskSync(req); + + Assertions.assertEquals(taskId, resp.getId()); + + log.info(new String(resp.getOutput())); + } + // @Test @WithUserDetails(MockUser.URSULA) public void testItCanSendAmrToMmtRequest() throws Exception {