Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

if we can't find a grounding, we set as a default empty grounding #5367

Merged
merged 5 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/client/hmi-client/src/services/goLLM.ts
YohannParis marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export async function interventionPolicyFromDocument(

export async function enrichModelMetadata(modelId: string, documentId: string, overwrite: boolean): Promise<void> {
try {
await API.get<TaskResponse>('/gollm/enrich-model-metadata', {
await API.get('/gollm/enrich-model-metadata', {
params: {
'model-id': modelId,
'document-id': documentId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import software.uncharted.terarium.hmiserver.models.dataservice.dataset.Dataset;
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.model.Model;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.ModelMetadata;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest.TaskType;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
Expand Down Expand Up @@ -263,7 +264,7 @@ public ResponseEntity<TaskResponse> createConfigureModelFromDocumentTask(

// stripping the metadata from the model before its sent since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
model.get().setMetadata(new ModelMetadata());
input.setAmr(model.get().serializeWithoutTerariumFields(new String[] { "id" }, null));

// Create the task
Expand Down Expand Up @@ -391,7 +392,7 @@ public ResponseEntity<TaskResponse> createConfigureModelFromDatasetTask(
input.setDataset(dataArray);
// stripping the metadata from the model before its sent since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
model.get().setMetadata(new ModelMetadata());
input.setAmr(model.get().serializeWithoutTerariumFields(null, null));

// set matrix string if provided
Expand Down Expand Up @@ -506,7 +507,7 @@ public ResponseEntity<TaskResponse> createInterventionsFromDocumentTask(

// stripping the metadata from the model before its sent since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
model.get().setMetadata(new ModelMetadata());
input.setAmr(model.get().serializeWithoutTerariumFields(new String[] { "id" }, null));

// Create the task
Expand Down Expand Up @@ -742,7 +743,7 @@ public ResponseEntity<TaskResponse> createGenerateResponseTask(
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = UUID.class)
)
),
@ApiResponse(
Expand All @@ -753,10 +754,10 @@ public ResponseEntity<TaskResponse> createGenerateResponseTask(
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
public ResponseEntity<TaskResponse> createEnrichModelMetadataTask(
public ResponseEntity<UUID> createEnrichModelMetadataTask(
@RequestParam(name = "model-id", required = true) final UUID modelId,
@RequestParam(name = "document-id", required = false) final UUID documentId,
@RequestParam(name = "mode", required = false, defaultValue = "ASYNC") final TaskMode mode,
@RequestParam(name = "mode", required = false, defaultValue = "SYNC") final TaskMode mode,
@RequestParam(name = "project-id", required = false) final UUID projectId,
@RequestParam(name = "overwrite", required = false, defaultValue = "false") final boolean overwrite
) {
Expand All @@ -765,9 +766,8 @@ public ResponseEntity<TaskResponse> createEnrichModelMetadataTask(
projectId
);

final TaskResponse resp;
try {
resp = modelService.enrichModel(projectId, Optional.of(documentId), modelId, permission, true);
modelService.enrichModel(projectId, documentId, modelId, permission, true);
} catch (final IOException e) {
log.error("An error occurred while trying to retrieve information necessary for model enrichment.", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("postgres.service-unavailable"));
Expand All @@ -782,7 +782,14 @@ public ResponseEntity<TaskResponse> createEnrichModelMetadataTask(
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
}

return ResponseEntity.ok().body(resp);
// check that we have a model to return
final Optional<Model> model = modelService.getAsset(modelId, permission);
if (model.isEmpty()) {
log.error(String.format("The model %s does not exist.", modelId));
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("model.not-found"));
}

return ResponseEntity.ok().body(model.get().getId());
}

@GetMapping("/enrich-amr")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public ResponseEntity<UUID> equationsToModel(
final Model model = modelService.createAsset(responseAMR, projectId, permission);
// enrich the model with the document
if (documentId != null) {
modelService.enrichModel(projectId, Optional.of(documentId), model.getId(), permission, true);
modelService.enrichModel(projectId, documentId, model.getId(), permission, true);
}
return ResponseEntity.ok(model.getId());
} catch (final IOException e) {
Expand All @@ -258,9 +258,7 @@ public ResponseEntity<UUID> equationsToModel(
}

// If a model id is provided, update the existing model
final Optional<Model> model;

model = modelService.getAsset(modelId, permission);
final Optional<Model> model = modelService.getAsset(modelId, permission);
if (model.isEmpty()) {
log.error(String.format("The model id %s does not exist.", modelId));
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("model.not-found"));
Expand All @@ -271,7 +269,7 @@ public ResponseEntity<UUID> equationsToModel(
modelService.updateAsset(responseAMR, projectId, permission);
// enrich the model with the document
if (documentId != null) {
modelService.enrichModel(projectId, Optional.of(documentId), responseAMR.getId(), permission, true);
modelService.enrichModel(projectId, documentId, responseAMR.getId(), permission, true);
}
return ResponseEntity.ok(model.get().getId());
} catch (final IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import java.util.Optional;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.repository.NoRepositoryBean;

@NoRepositoryBean
public interface PSCrudSoftDeleteRepository<T, ID> extends PSCrudRepository<T, ID> {
public interface PSCrudSoftDeleteRepository<T, ID> extends JpaRepository<T, ID> {
List<T> findAllByIdInAndDeletedOnIsNull(final List<ID> ids);

Optional<T> getByIdAndDeletedOnIsNull(final ID id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import co.elastic.clients.elasticsearch.core.search.SourceFilter;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.observation.annotation.Observed;
import jakarta.persistence.EntityManager;
import jakarta.persistence.PersistenceContext;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -262,23 +264,18 @@ public Optional<Model> updateAsset(
return updatedOptional;
}

public TaskResponse enrichModel(
public UUID enrichModel(
final UUID projectId,
final Optional<UUID> documentId,
final UUID documentId,
final UUID modelId,
final Schema.Permission permission,
final boolean overwrite
) throws IOException, ExecutionException, InterruptedException, TimeoutException {
// Grab the document
final Optional<DocumentAsset> document = documentAssetService.getAsset(documentId.get(), permission);
if (document.isEmpty()) {
String errorString = String.format("Document %s not found", documentId);
log.warn(errorString);
throw new IOException(errorString);
}
// Grab the document if it exists
final Optional<DocumentAsset> document = documentAssetService.getAsset(documentId, permission);

// make sure there is text in the document
if (document.get().getText() == null || document.get().getText().isEmpty()) {
if (document.isPresent() && (document.get().getText() == null || document.get().getText().isEmpty())) {
String errorString = String.format("Document %s has no extracted text", documentId);
log.warn(errorString);
throw new IOException(errorString);
Expand All @@ -294,7 +291,7 @@ public TaskResponse enrichModel(

// stripping the metadata from the model before its sent since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
model.get().setMetadata(new ModelMetadata());

final TaskRequest req;

Expand Down Expand Up @@ -323,78 +320,78 @@ public TaskResponse enrichModel(
final TaskResponse resp = taskService.runTask(TaskService.TaskMode.SYNC, req);

// at this point the initial enrichment has happened.
model = getAsset(modelId, permission);
if (model.isEmpty()) {
final Optional<Model> newModel = getAsset(modelId, permission);
if (newModel.isEmpty()) {
String errorString = String.format("Model %s not found", modelId);
log.warn(errorString);
throw new IOException(errorString);
}

// Update State Grounding
if (model.get().isRegnet()) {
List<RegNetVertex> vertices = model.get().getVerticies();
if (newModel.get().isRegnet()) {
List<RegNetVertex> vertices = newModel.get().getVerticies();
vertices.forEach(vertex -> {
if (vertex == null) {
vertex = new RegNetVertex();
}
TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, vertex);
});
model.get().setVerticies(vertices);
newModel.get().setVerticies(vertices);
} else {
List<State> states = model.get().getStates();
List<State> states = newModel.get().getStates();
states.forEach(state -> {
if (state == null) {
state = new State();
}
TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, state);
});
model.get().setStates(states);
newModel.get().setStates(states);
}

//Update Observable Grounding
if (model.get().getObservables() != null && !model.get().getObservables().isEmpty()) {
List<Observable> observables = model.get().getObservables();
if (newModel.get().getObservables() != null && !newModel.get().getObservables().isEmpty()) {
List<Observable> observables = newModel.get().getObservables();
observables.forEach(observable -> {
if (observable == null) {
observable = new Observable();
}
TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, observable);
});
model.get().setObservables(observables);
newModel.get().setObservables(observables);
}

//Update Parameter Grounding
if (model.get().getParameters() != null && !model.get().getParameters().isEmpty()) {
List<ModelParameter> parameters = model.get().getParameters();
if (newModel.get().getParameters() != null && !newModel.get().getParameters().isEmpty()) {
List<ModelParameter> parameters = newModel.get().getParameters();
parameters.forEach(parameter -> {
if (parameter == null) {
parameter = new ModelParameter();
}
TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, parameter);
});
model.get().setParameters(parameters);
newModel.get().setParameters(parameters);
}

//Update Transition Grounding
if (model.get().getTransitions() != null && !model.get().getTransitions().isEmpty()) {
List<Transition> transitions = model.get().getTransitions();
if (newModel.get().getTransitions() != null && !newModel.get().getTransitions().isEmpty()) {
List<Transition> transitions = newModel.get().getTransitions();
transitions.forEach(transition -> {
if (transition == null) {
transition = new Transition();
}
TaskUtilities.performDKGSearchAndSetGrounding(miraProxy, transition);
});
model.get().setTransitions(transitions);
newModel.get().setTransitions(transitions);
}

try {
updateAsset(model.get(), projectId, permission);
updateAsset(newModel.get(), projectId, permission);
} catch (IOException e) {
String errorString = String.format("Failed to update model %s", modelId);
log.warn(errorString);
throw new IOException(errorString);
}

return resp;
return newModel.get().getId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public Optional<T> deleteAsset(final UUID id, final UUID projectId, final Schema
return Optional.empty();
}
asset.get().setDeletedOn(Timestamp.from(Instant.now()));
repository.save(asset.get());
repository.saveAndFlush(asset.get());
return asset;
}

Expand All @@ -146,7 +146,7 @@ public T createAsset(final T asset, final UUID projectId, final Schema.Permissio

asset.setPublicAsset(projectService.isProjectPublic(projectId));

return repository.save(asset);
return repository.saveAndFlush(asset);
}

/**
Expand All @@ -169,7 +169,7 @@ public List<T> createAssets(final List<T> assets, final UUID projectId, final Sc
final boolean projectIsPublic = projectService.isProjectPublic(projectId);
assets.forEach(asset -> asset.setPublicAsset(projectIsPublic));

return repository.saveAll(assets);
return repository.saveAllAndFlush(assets);
}

/**
Expand Down Expand Up @@ -198,7 +198,7 @@ public Optional<T> updateAsset(final T asset, final UUID projectId, final Schema

asset.setPublicAsset(projectService.isProjectPublic(projectId));

final T updated = repository.save(asset);
final T updated = repository.saveAndFlush(asset);

// Update the related ProjectAsset
projectAssetService.updateByAsset(updated, hasWritePermission);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public static TaskRequest getModelCardTask(String userId, DocumentAsset document
} catch (JsonProcessingException e) {
throw new IOException("Unable to serialize document text");
}
} else {
input.setResearchPaper("");
}

// Create the task
Expand Down
11 changes: 6 additions & 5 deletions packages/server/src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
########################################################################################################################
# Database configuration
########################################################################################################################
spring.datasource.url=jdbc:postgresql://10.64.22.49:5432/terarium
spring.datasource.hikari.maximum-pool-size=64
spring.datasource.initialize=false
spring.datasource.password=${terarium.db.password}
spring.datasource.url=jdbc:postgresql://10.64.22.49:5432/terarium
spring.datasource.username=${terarium.db.username}
spring.datasource.initialize=false
spring.jpa.hibernate.ddl-auto=none
spring.jpa.database-platform=org.hibernate.dialect.PostgreSQLDialect
spring.flyway.enabled=false
spring.datasource.hikari.maximum-pool-size=64
spring.jpa.database-platform=org.hibernate.dialect.PostgreSQLDialect
spring.jpa.hibernate.ddl-auto=none
spring.jpa.open-in-view=false

########################################################################################################################
# Elasticsearch configuration
Expand Down
Loading