Skip to content

Commit

Permalink
fixup! discojs-core/train/models: initial
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Feb 29, 2024
1 parent c48be15 commit dd8a4f6
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions server/src/tasks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import fs from 'node:fs/promises'
import tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-node'

import type { Model, Task, Path, Digest, TaskProvider } from '@epfml/discojs-core'
import { isTaskProvider, defaultTasks, models, serialization } from '@epfml/discojs-core'
import type { Task, Path, Digest, TaskProvider } from '@epfml/discojs-core'
import { Model, isTaskProvider, defaultTasks, models, serialization } from '@epfml/discojs-core'

// default tasks and added ones
// register 'taskAndModel' event to get tasks
Expand Down Expand Up @@ -97,19 +97,25 @@ export class TasksAndModels {
}

async addTaskAndModel (task: Task | TaskProvider, model?: Model | URL): Promise<void> {
let discoTask: Task
if (isTaskProvider(task)) {
task = task.getTask()
discoTask = task.getTask()
} else {
discoTask = task
}

let tfModel: Model
if (model === undefined) {
model = await this.loadModelFromTask(task)
tfModel = await this.loadModelFromTask(task)
} else if (model instanceof Model) {
tfModel = model
} else if (model instanceof URL) {
model = new models.TFJS(await tf.loadLayersModel(model.href))
tfModel = new models.TFJS(await tf.loadLayersModel(model.href))
} else {
throw new Error('invalid model')
}

this.tasksAndModels = this.tasksAndModels.add([task, model])
this.emit('taskAndModel', task, model)
this.tasksAndModels = this.tasksAndModels.add([discoTask, tfModel])
this.emit('taskAndModel', discoTask, tfModel)
}
}

0 comments on commit dd8a4f6

Please sign in to comment.