Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model card on document upload #2693

Merged
merged 4 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ import InputText from 'primevue/inputtext';
import { useToastService } from '@/services/toast';
import TeraImportGithubFile from '@/components/widgets/tera-import-github-file.vue';
import { extractPDF } from '@/services/knowledge';
import { modelCard } from '@/services/goLLM';

defineProps<{
visible: boolean;
Expand Down Expand Up @@ -153,6 +154,11 @@ async function upload() {
const newAsset = useProjects().addAsset(assetType, id);
if (name && name.toLowerCase().endsWith('.pdf')) {
extractPDF(id);
} else if (
(name && name.toLowerCase().endsWith('.txt')) ||
(name && name.toLowerCase().endsWith('.md'))
) {
modelCard(id);
YohannParis marked this conversation as resolved.
Show resolved Hide resolved
}
return newAsset;
})
Expand Down
21 changes: 16 additions & 5 deletions packages/client/hmi-client/src/services/document-assets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
* Documents Asset
*/

import API from '@/api/api';
import API, { PollerState } from '@/api/api';
import type { AddDocumentAssetFromXDDResponse, Document, DocumentAsset } from '@/types/Types';
import { logger } from '@/utils/logger';
import { Ref } from 'vue';
import { fetchExtraction } from './knowledge';
import { modelCard } from './goLLM';

/**
* Get all documents
Expand Down Expand Up @@ -209,16 +211,25 @@ async function createDocumentFromXDD(
projectId: string
): Promise<AddDocumentAssetFromXDDResponse | null> {
if (!document || !projectId) return null;
const response = await API.post(`/document-asset/create-document-from-xdd`, {
document,
projectId
});
const response = await API.post<AddDocumentAssetFromXDDResponse>(
`/document-asset/create-document-from-xdd`,
{
document,
projectId
}
);

if (!response || response.status >= 400) {
logger.error('Error upload file from doi');
return null;
}

if (response.data.extractionJobId) {
const result = await fetchExtraction(response.data.extractionJobId);
if (result.state === PollerState.Done) {
modelCard(response.data.documentAssetId);
}
}
YohannParis marked this conversation as resolved.
Show resolved Hide resolved
return response.data;
}
export {
Expand Down
5 changes: 2 additions & 3 deletions packages/client/hmi-client/src/services/goLLM.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ import { logger } from '@/utils/logger';
* @param {string} documentId - The document ID.
* @param {string} modelId - The model ID.
*/
export async function modelCard(documentId: string, modelId: string): Promise<void> {
export async function modelCard(documentId: string): Promise<void> {
try {
const response = await API.post<TaskResponse>('/gollm/model-card', null, {
params: {
'document-id': documentId,
'model-id': modelId
'document-id': documentId
}
});

Expand Down
2 changes: 2 additions & 0 deletions packages/client/hmi-client/src/services/knowledge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import API, { Poller, PollerState, PollResponse, PollerResult } from '@/api/api'
import { AxiosError, AxiosResponse } from 'axios';
import type { Code, Dataset, ExtractionResponse, Model } from '@/types/Types';
import { logger } from '@/utils/logger';
import { modelCard } from './goLLM';

/**
* Fetch information from the extraction service via the Poller utility
Expand Down Expand Up @@ -192,6 +193,7 @@ export const extractPDF = async (documentId: string) => {
if (resp) {
const pollResult = await fetchExtraction(resp);
if (pollResult?.state === PollerState.Done) {
modelCard(documentId);
pdfExtractions(documentId, Extractor.SKEMA);
pdfExtractions(documentId, Extractor.MIT);
}
Expand Down
2 changes: 1 addition & 1 deletion packages/client/hmi-client/src/services/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,6 @@ export async function generateModelCard(
}

if (modelServiceType === ModelServiceType.TA4) {
await modelCard(documentId, modelId);
await modelCard(documentId);
blanchco marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ import 'ace-builds/src-noconflict/mode-python';
import 'ace-builds/src-noconflict/mode-julia';
import 'ace-builds/src-noconflict/mode-r';
import { AssetType, ProgrammingLanguage } from '@/types/Types';
import type { Card, Code, Model } from '@/types/Types';
import type { Card, Code, DocumentAsset, Model } from '@/types/Types';
import { AssetBlock, WorkflowNode, WorkflowOutput } from '@/types/workflow';
import { KernelSessionManager } from '@/services/jupyter';
import { logger } from '@/utils/logger';
Expand All @@ -173,6 +173,7 @@ import TeraModelCard from '@/components/model/petrinet/tera-model-card.vue';
import TeraOutputDropdown from '@/components/drilldown/tera-output-dropdown.vue';
import { ModelServiceType } from '@/types/common';
import { extensionFromProgrammingLanguage } from '@/utils/data-util';
import { getDocumentAsset } from '@/services/document-assets';
import { ModelFromCodeState } from './model-from-code-operation';

const props = defineProps<{
Expand Down Expand Up @@ -204,6 +205,10 @@ const kernelManager = new KernelSessionManager();
const selectedModel = ref<Model | null>(null);
const documentId = computed(() => props.node.inputs?.[1]?.value?.[0]);

const document = ref<DocumentAsset | null>(null);

const goLLMCard = computed<any>(() => document.value?.metadata?.gollm_card);

const inputCodeBlocks = ref<AssetBlock<CodeBlock>[]>([]);

const allCodeBlocks = computed<AssetBlock<CodeBlock>[]>(() => {
Expand Down Expand Up @@ -276,14 +281,18 @@ const selectedOutput = computed<WorkflowOutput<ModelFromCodeState> | undefined>(
);

const card = ref<Card | null>(null);
const goLLMCard = ref<any>(null);

onMounted(async () => {
clonedState.value = cloneDeep(props.node.state);

if (selectedOutputId.value) {
onUpdateOutput(selectedOutputId.value);
}

if (documentId.value) {
document.value = await getDocumentAsset(documentId.value);
}

fetchingInputBlocks.value = true;
await getInputCodeBlocks();
fetchingInputBlocks.value = false;
Expand Down Expand Up @@ -501,6 +510,7 @@ function isSaveModelDisabled(): boolean {
return !selectedModel.value || !!activeProjectModelIds?.includes(selectedModel.value.id);
}

// generates the model card and fetches the model when finished
async function generateCard(docId, modelId) {
if (!docId || !modelId) return;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ const assetLoading = ref(false);
const loadingModel = ref(false);
const selectedModel = ref<Model | null>(null);
const card = ref<Card | null>(null);
const goLLMCard = ref<any>(null);
const goLLMCard = computed<any>(() => document.value?.metadata?.gollm_card);

const formSteps = ref([
{
Expand Down Expand Up @@ -417,6 +417,7 @@ function removeEquation(index: number) {
emit('update-state', clonedState.value);
}

// generates the model card and fetches the model when finished
async function generateCard(docId, modelId) {
if (!docId || !modelId) return;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import software.uncharted.terarium.hmiserver.annotations.IgnoreRequestLogging;
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.model.Model;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.models.task.TaskStatus;
import software.uncharted.terarium.hmiserver.security.Roles;
import software.uncharted.terarium.hmiserver.service.TaskResponseHandler;
import software.uncharted.terarium.hmiserver.service.TaskService;
import software.uncharted.terarium.hmiserver.service.data.DocumentAssetService;
import software.uncharted.terarium.hmiserver.service.data.ModelService;

import java.io.IOException;
import java.util.Optional;
Expand All @@ -41,7 +39,6 @@ public class GoLLMController {
final private ObjectMapper objectMapper;
final private TaskService taskService;
final private DocumentAssetService documentAssetService;
final private ModelService modelService;

final private String MODEL_CARD_SCRIPT = "gollm:model_card";

Expand All @@ -58,7 +55,6 @@ private static class ModelCardResponse {

@Data
private static class ModelCardProperties {
UUID modelId;
UUID documentId;
}

Expand All @@ -73,12 +69,16 @@ private TaskResponseHandler getModelCardResponseHandler() {
try {
final String serializedString = objectMapper.writeValueAsString(resp.getAdditionalProperties());
final ModelCardProperties props = objectMapper.readValue(serializedString, ModelCardProperties.class);
log.info("Writing model card to database for model {}", props.getModelId());
final Model model = modelService.getAsset(props.getModelId())
log.info("Writing model card to database for document {}", props.getDocumentId());
final DocumentAsset document = documentAssetService.getAsset(props.getDocumentId())
.orElseThrow();
final ModelCardResponse card = objectMapper.readValue(resp.getOutput(), ModelCardResponse.class);
model.getMetadata().setGollmCard(card.response);
modelService.updateAsset(model);
if (document.getMetadata() == null){
document.setMetadata(new java.util.HashMap<>());
}
document.getMetadata().put("goLLM_card", card.response);
blanchco marked this conversation as resolved.
Show resolved Hide resolved

documentAssetService.updateAsset(document);
} catch (final IOException e) {
log.error("Failed to write model card to database", e);
}
Expand All @@ -99,16 +99,9 @@ private TaskResponseHandler getModelCardResponseHandler() {
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
})
public ResponseEntity<TaskResponse> createModelCardTask(
@RequestParam(name = "document-id", required = true) final UUID documentId,
@RequestParam(name = "model-id", required = true) final UUID modelId) {
@RequestParam(name = "document-id", required = true) final UUID documentId) {

try {
// Ensure the model is valid
final Optional<Model> model = modelService.getAsset(modelId);
if (model.isEmpty()) {
return ResponseEntity.notFound().build();
}

// Grab the document
final Optional<DocumentAsset> document = documentAssetService.getAsset(documentId);
if (document.isEmpty()) {
Expand Down Expand Up @@ -137,7 +130,6 @@ public ResponseEntity<TaskResponse> createModelCardTask(
req.setInput(objectMapper.writeValueAsBytes(input));

final ModelCardProperties props = new ModelCardProperties();
props.setModelId(modelId);
props.setDocumentId(documentId);
req.setAdditionalProperties(props);

Expand Down
Loading