Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

configure model from dataset #2891

Merged
merged 9 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion packages/client/hmi-client/src/services/goLLM.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,33 @@ export async function configureModel(documentId: string, modelId: string): Promi
}
});
} catch (err) {
logger.error(`An issue occured while exctracting a model configuration. ${err}`);
logger.error(`An issue occured while exctracting a model configuration from document. ${err}`);
}
}

export async function configureModelFromDatasets(modelId: string, datasetIds: string[]) {
try {
// FIXME: Using first dataset for now...
blanchco marked this conversation as resolved.
Show resolved Hide resolved
const response = await API.post<TaskResponse>('/gollm/configure-from-dataset', null, {
params: {
'model-id': modelId,
'dataset-ids': datasetIds[0]
blanchco marked this conversation as resolved.
Show resolved Hide resolved
}
});

const taskId = response.data.id;
await handleTaskById(taskId, {
ondata(data, closeConnection) {
if (data?.status === TaskStatus.Failed) {
throw new FatalError('Task failed');
}
if (data.status === TaskStatus.Success) {
closeConnection();
}
}
});
} catch (err) {
logger.error(`An issue occured while exctracting a model configuration from dataset. ${err}`);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ export const ModelConfigOperation: Operation = {
displayName: 'Configure model',
description: 'Create model configurations.',
isRunnable: true,
inputs: [{ type: 'modelId' }, { type: 'documentId', label: 'Document', isOptional: true }],
inputs: [
{ type: 'modelId' },
{ type: 'documentId', label: 'Document', isOptional: true },
{ type: 'datasetId', label: 'Dataset', isOptional: true }
],
outputs: [{ type: 'modelConfigId' }],
action: async () => ({}),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
>({{ suggestedConfirgurationContext.tableData.length }})</span
>
<Button
outlined
label="Extract configurations from a document"
size="small"
icon="pi pi-cog"
@click.stop="extractConfigurations"
:disabled="loadingConfigs || !documentId || !model.id"
label="Extract Configurations"
blanchco marked this conversation as resolved.
Show resolved Hide resolved
outlined
@click.stop="toggleExtractionMenu"
style="margin-left: auto"
:loading="isFetchingConfigsFromDataset || isFetchingConfigsFromDocument"
/>
<Menu ref="extractionMenu" :model="menuItems" popup />
</template>

<DataTable
Expand All @@ -38,7 +38,7 @@
:rows="5"
sort-field="createdOn"
:sort-order="-1"
:loading="loadingConfigs"
:loading="isFetchingConfigsFromDocument || isFetchingConfigsFromDataset"
>
<Column field="name" header="Name" style="width: 15%">
<template #body="{ data }">
Expand Down Expand Up @@ -266,7 +266,7 @@ import { useToastService } from '@/services/toast';
import TeraOutputDropdown from '@/components/drilldown/tera-output-dropdown.vue';
import { logger } from '@/utils/logger';
import TeraModelDiagram from '@/components/model/petrinet/model-diagrams/tera-model-diagram.vue';
import { configureModel } from '@/services/goLLM';
import { configureModel, configureModelFromDatasets } from '@/services/goLLM';
import DataTable from 'primevue/datatable';
import Column from 'primevue/column';
import TeraNotebookJupyterInput from '@/components/llm/tera-notebook-jupyter-input.vue';
Expand All @@ -279,6 +279,8 @@ import LoadingWateringCan from '@/assets/images/lottie-loading-wateringCan.json'
import EmptySeed from '@/assets/images/lottie-empty-seed.json';
import { Vue3Lottie } from 'vue3-lottie';
import TeraModelSemanticTables from '@/components/model/petrinet/tera-model-semantic-tables.vue';
import Menu from 'primevue/menu';
import { MenuItem } from 'primevue/menuitem';
import { ModelConfigOperation, ModelConfigOperationState } from './model-config-operation';
import TeraModelConfigTable from './tera-model-config-table.vue';

Expand Down Expand Up @@ -396,6 +398,43 @@ const runFromCode = () => {
const edges = computed(() => modelConfiguration?.value?.configuration?.model?.edges ?? []);
const vertices = computed(() => modelConfiguration?.value?.configuration.model?.vertices ?? []);

const extractionMenu = ref();
const toggleExtractionMenu = (event) => {
extractionMenu.value.toggle(event);
};
const menuItems = computed<MenuItem[]>(() => {
const items: MenuItem[] = [];
if (documentId.value) {
items.push({
label: 'From a document',
command: () => {
extractConfigurationsFromDocument();
}
});
}

if (datasetId.value) {
items.push({
label: 'From a dataset',
command: () => {
extractConfigurationsFromDataset();
}
});
}

if (documentId.value && datasetId.value) {
items.push({
label: 'From both',
command: () => {
extractConfigurationsFromDataset();
extractConfigurationsFromDocument();
}
});
}

return items;
});

// FIXME: Copy pasted in 3 locations, could be written cleaner and in a service
const saveCodeToState = (code: string, hasCodeBeenRun: boolean) => {
const state = _.cloneDeep(props.node.state);
Expand Down Expand Up @@ -435,6 +474,7 @@ const selectedConfigId = computed(
);

const documentId = computed(() => props.node.inputs?.[1]?.value?.[0]?.documentId);
const datasetId = computed(() => props.node.inputs?.[2]?.value?.[0]);

const suggestedConfirgurationContext = ref<{
isOpen: boolean;
Expand All @@ -446,8 +486,9 @@ const suggestedConfirgurationContext = ref<{
modelConfiguration: null
});

const loadingConfigs = ref(false);
const model = ref<Model | null>();
const isFetchingConfigsFromDocument = ref(false);
const isFetchingConfigsFromDataset = ref(false);
const model = ref<Model | null>(null);

const modelConfiguration = computed<ModelConfiguration | null>(() => {
if (!model.value) return null;
Expand Down Expand Up @@ -697,9 +738,12 @@ const onSelection = (id: string) => {

const fetchConfigurations = async (modelId: string) => {
if (modelId) {
loadingConfigs.value = true;
suggestedConfirgurationContext.value.tableData = await getModelConfigurations(modelId);
loadingConfigs.value = false;
// FIXME: since configurations are made on the backend on the fly, we need to wait for the db to update before fetching, here's an artificaial delay
blanchco marked this conversation as resolved.
Show resolved Hide resolved
setTimeout(async () => {
isFetchingConfigsFromDocument.value = true;
suggestedConfirgurationContext.value.tableData = await getModelConfigurations(modelId);
isFetchingConfigsFromDocument.value = false;
}, 800);
}
};

Expand Down Expand Up @@ -788,11 +832,19 @@ const useSuggestedConfig = (config: ModelConfiguration) => {
logger.success(`Configuration applied ${config.name}`);
};

const extractConfigurations = async () => {
const extractConfigurationsFromDocument = async () => {
if (!documentId.value || !model.value?.id) return;
loadingConfigs.value = true;
isFetchingConfigsFromDocument.value = true;
await configureModel(documentId.value, model.value.id);
loadingConfigs.value = false;
isFetchingConfigsFromDocument.value = false;
fetchConfigurations(model.value.id);
};

const extractConfigurationsFromDataset = async () => {
if (!datasetId.value || !model.value?.id) return;
isFetchingConfigsFromDataset.value = true;
await configureModelFromDatasets(model.value.id, [datasetId.value]);
isFetchingConfigsFromDataset.value = false;
fetchConfigurations(model.value.id);
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
import org.springframework.web.bind.annotation.*;
import org.springframework.web.server.ResponseStatusException;
import software.uncharted.terarium.hmiserver.models.dataservice.ResponseDeleted;
import software.uncharted.terarium.hmiserver.models.dataservice.dataset.Dataset;
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.model.Model;
import software.uncharted.terarium.hmiserver.models.dataservice.model.ModelConfiguration;
import software.uncharted.terarium.hmiserver.models.dataservice.model.ModelDescription;
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceQueryParam;
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceType;
import software.uncharted.terarium.hmiserver.security.Roles;
import software.uncharted.terarium.hmiserver.service.data.DatasetService;
import software.uncharted.terarium.hmiserver.service.data.DocumentAssetService;
import software.uncharted.terarium.hmiserver.service.data.ModelService;
import software.uncharted.terarium.hmiserver.service.data.ProvenanceSearchService;
Expand All @@ -46,6 +48,8 @@ public class ModelController {

final ObjectMapper objectMapper;

final DatasetService datasetService;

@GetMapping("/descriptions")
@Secured(Roles.USER)
@Operation(summary = "Gets all model descriptions")
Expand Down Expand Up @@ -275,11 +279,11 @@ ResponseEntity<List<ModelConfiguration>> getModelConfigurationsForModelId(


// Find the Document Assets linked via provenance to the model configuration
final ProvenanceQueryParam body = new ProvenanceQueryParam();
body.setRootId(config.getId());
body.setRootType(ProvenanceType.MODEL_CONFIGURATION);
body.setTypes(List.of(ProvenanceType.DOCUMENT));
final Set<String> documentIds = provenanceSearchService.modelConfigFromDocument(body);
final ProvenanceQueryParam documentSearchBody = new ProvenanceQueryParam();
blanchco marked this conversation as resolved.
Show resolved Hide resolved
documentSearchBody.setRootId(config.getId());
documentSearchBody.setRootType(ProvenanceType.MODEL_CONFIGURATION);
documentSearchBody.setTypes(List.of(ProvenanceType.DOCUMENT));
final Set<String> documentIds = provenanceSearchService.modelConfigFromDocument(documentSearchBody);

List<String> documentSourceNames = new ArrayList<String>();
documentIds.forEach(documentId -> {
Expand All @@ -295,17 +299,45 @@ ResponseEntity<List<ModelConfiguration>> getModelConfigurationsForModelId(
log.error("Unable to get the document " + documentId, e);
}
});

// Find the Dataset Assets linked via provenance to the model configuration
final ProvenanceQueryParam datasetSearchBody = new ProvenanceQueryParam();
datasetSearchBody.setRootId(config.getId());
datasetSearchBody.setRootType(ProvenanceType.MODEL_CONFIGURATION);
datasetSearchBody.setTypes(List.of(ProvenanceType.DATASET));
final Set<String> datasetIds = provenanceSearchService.modelConfigFromDataset(datasetSearchBody);
blanchco marked this conversation as resolved.
Show resolved Hide resolved

List<String> datasetSourceNames = new ArrayList<String>();
datasetIds.forEach(datasetId -> {
try {
// Fetch the Document extractions
final Optional<Dataset> dataset = datasetService
.getAsset(UUID.fromString(datasetId));
if (dataset.isPresent()) {
final String name = dataset.get().getName();
documentSourceNames.add(name);
}
} catch (final Exception e) {
log.error("Unable to get the document " + datasetId, e);
}
});


List<String> sourceNames = new ArrayList<String>();
sourceNames.addAll(documentSourceNames);
sourceNames.addAll(datasetSourceNames);

final ObjectNode metadata = (ObjectNode) configuration.get("metadata");

metadata.set("source", objectMapper.valueToTree(documentSourceNames));
metadata.set("source", objectMapper.valueToTree(sourceNames));

((ObjectNode) configuration).set("metadata", metadata);

config.setConfiguration(configuration);
});

return ResponseEntity.ok(modelConfigurations);
} catch (final IOException e) {
} catch (final Exception e) {
final String error = "Unable to get model configurations";
log.error(error, e);
throw new ResponseStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public ResponseEntity<TaskResponse> createConfigureModelTask(
})
public ResponseEntity<TaskResponse> createConfigFromDatasetTask(
@RequestParam(name = "model-id", required = true) final UUID modelId,
@RequestParam(name = "document-ids", required = true) final List<UUID> datasetIds) {
@RequestParam(name = "dataset-ids", required = true) final List<UUID> datasetIds) {

try {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,38 @@ public Set<String> modelConfigFromDocument(final ProvenanceQueryParam payload) {
}

try (final Session session = neo4jService.getSession()) {
final UUID modelId = payload.getRootId();
final UUID modelConfigurationId = payload.getRootId();

final String query = String.format("MATCH (d:Document)<-[r:EXTRACTED_FROM]-(m:ModelConfiguration {id: '%s'}) RETURN d",
modelId);
modelConfigurationId);

final Result response = session.run(query);
final Set<String> responseData = new HashSet<>();
while (response.hasNext()) {
responseData.add(response.next().get("d").get("id").asString());
}
return responseData;
}
}

/**
*
* Identifies the dataset from which a model configuration was extracted
*
* @param payload - Search param payload.
* @return
*/
public Set<String> modelConfigFromDataset(final ProvenanceQueryParam payload) {
if (payload.getRootType() != ProvenanceType.MODEL_CONFIGURATION) {
throw new IllegalArgumentException(
"Dataset used for model-configuration extraction can only be found by providing a model-confirguration");
}

try (final Session session = neo4jService.getSession()) {
final UUID modelConfigurationId = payload.getRootId();

final String query = String.format("MATCH (d:Dataset)<-[r:EXTRACTED_FROM]-(m:ModelConfiguration {id: '%s'}) RETURN d",
modelConfigurationId);

final Result response = session.run(query);
final Set<String> responseData = new HashSet<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

Expand Down Expand Up @@ -70,26 +71,28 @@ public void onSuccess(TaskResponse resp) {
.orElseThrow();
final Response configurations = objectMapper.readValue(((TaskResponse) resp).getOutput(),
Response.class);

// For each configuration, create a new model configuration with parameters set
configurations.response.get("conditions").forEach((condition) -> {
// Map the parameters values to the model
final Model modelCopy = new Model(model);
final List<ModelParameter> modelParameters = modelCopy.getSemantics().getOde().getParameters();
modelParameters.forEach((parameter) -> {
JsonNode conditionParameters = condition.get("parameters");
conditionParameters.forEach((conditionParameter) -> {
if (parameter.getId().equals(conditionParameter.get("id").asText())) {
parameter.setValue(conditionParameter.get("value").doubleValue());
}
});
});
// Map the parameters values to the model
final Model modelCopy = new Model(model);
List<ModelParameter> modelParameters;
if(modelCopy.getHeader().getSchemaName().toLowerCase().equals("regnet")) {
blanchco marked this conversation as resolved.
Show resolved Hide resolved
modelParameters = objectMapper.convertValue(modelCopy.getModel().get("parameters"), new TypeReference<List<ModelParameter>>() {});
} else {
modelParameters = modelCopy.getSemantics().getOde().getParameters();
}
modelParameters.forEach((parameter) -> {
JsonNode conditionParameters = configurations.getResponse().get("parameters");
conditionParameters.forEach((conditionParameter) -> {
if (parameter.getId().equals(conditionParameter.get("id").asText())) {
parameter.setValue(conditionParameter.get("value").doubleValue());
}
});
});

// Create the new configuration
final ModelConfiguration configuration = new ModelConfiguration();
configuration.setModelId(model.getId());
configuration.setName(condition.get("name").asText());
configuration.setDescription(condition.get("description").asText());
configuration.setName("New configuration from dataset");
configuration.setDescription("");
configuration.setConfiguration(modelCopy);

try {
Expand All @@ -107,7 +110,7 @@ public void onSuccess(TaskResponse resp) {
} catch (IOException e) {
log.error("Failed to set model configuration", e);
}
});


} catch (Exception e) {
log.error("Failed to configure model", e);
Expand Down
Loading
Loading