Skip to content

Commit

Permalink
server: use node:fs/promises
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Feb 29, 2024
1 parent bde41e7 commit 0bc8781
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 31 deletions.
56 changes: 26 additions & 30 deletions server/src/tasks.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { List, Set } from 'immutable'
import { createHash } from 'node:crypto'
import fs from 'node:fs'
import fs from 'node:fs/promises'
import tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-node'

import type { Task, Path, Digest, TaskProvider } from '@epfml/discojs-core'
import { isTaskProvider, defaultTasks } from '@epfml/discojs-core'
import { isTaskProvider, defaultTasks, serialization } from '@epfml/discojs-core'

// default tasks and added ones
// register 'taskAndModel' event to get tasks
Expand Down Expand Up @@ -36,40 +36,43 @@ export class TasksAndModels {
let model: tf.LayersModel | undefined

const modelPath = `./models/${discoTask.id}/`
if (fs.existsSync(modelPath)) {
// either a model has already been trained, or the pretrained
// model has already been downloaded
return await tf.loadLayersModel(`file://${modelPath}/model.json`)
try {
const content = await fs.readFile(`${modelPath}/model.json`)
return await serialization.model.decode(content)
} catch {
// unable to read file, continuing
}

const modelURL = discoTask.trainingInformation.modelURL
if (modelURL !== undefined) {
model = await tf.loadLayersModel(modelURL)
} else if (isTaskProvider(task)) {
model = await task.getModel()
} else {
const modelURL = discoTask.trainingInformation.modelURL
if (modelURL !== undefined) {
model = await tf.loadLayersModel(modelURL)
} else if (isTaskProvider(task)) {
model = await task.getModel()
} else {
throw new Error('model not provided in task definition')
}
throw new Error('model not provided in task definition')
}

fs.mkdirSync(modelPath, { recursive: true })
await model.save(`file://${modelPath}`, { includeOptimizer: true })
await fs.mkdir(modelPath, { recursive: true })
const encoded = await serialization.model.encode(model)
await fs.writeFile(`${modelPath}/model.json`, encoded)

// Check digest if provided
if (discoTask.digest !== undefined) {
try {
this.checkDigest(discoTask.digest, modelPath)
await this.checkDigest(discoTask.digest, modelPath)
} catch (e) {
TasksAndModels.removeModelFiles(modelPath)
console.warn('removing nodel files at', modelPath)
await fs.rm(modelPath, { recursive: true, force: true })
throw e
}
}

return model
}

private checkDigest (digest: Digest, modelPath: Path): void {
private async checkDigest (digest: Digest, modelPath: Path): Promise<void> {
const hash = createHash(digest.algorithm)
const modelConfigRaw = fs.readFileSync(`${modelPath}/model.json`)
const modelConfigRaw = await fs.readFile(`${modelPath}/model.json`)

const modelConfig = JSON.parse(modelConfigRaw.toString())
const weightsFiles = modelConfig.weightsManifest[0].paths
Expand All @@ -79,10 +82,10 @@ export class TasksAndModels {
)) {
throw new Error()
}
weightsFiles.forEach((file: string) => {
const data = fs.readFileSync(`${modelPath}/${file}`)
await Promise.all(weightsFiles.map(async (file: string) => {
const data = await fs.readFile(`${modelPath}/${file}`)
hash.update(data)
})
}))

const computedDigest = hash.digest('base64')
if (computedDigest !== digest.value) {
Expand Down Expand Up @@ -116,11 +119,4 @@ export class TasksAndModels {
this.tasksAndModels = this.tasksAndModels.add([discoTask, tfModel])
this.emit('taskAndModel', discoTask, tfModel)
}

static removeModelFiles (path: Path): void {
console.warn('removing nodel files at', path)
fs.rm(path, { recursive: true, force: true }, (err) => {
if (err !== null) console.error(err)
})
}
}
2 changes: 1 addition & 1 deletion server/tests/e2e/federated.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import fs from 'fs/promises'
import fs from 'node:fs/promises'
import path from 'node:path'
import type { Server } from 'node:http'
import { Range } from 'immutable'
Expand Down

0 comments on commit 0bc8781

Please sign in to comment.