From 0bc8781c0260d3dd7ebad9bc44808f0d939b43eb Mon Sep 17 00:00:00 2001 From: tharvik Date: Mon, 26 Feb 2024 00:44:28 +0100 Subject: [PATCH] server: use node:fs/promises --- server/src/tasks.ts | 56 ++++++++++++++---------------- server/tests/e2e/federated.spec.ts | 2 +- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/server/src/tasks.ts b/server/src/tasks.ts index 4f1725322..93a9363e9 100644 --- a/server/src/tasks.ts +++ b/server/src/tasks.ts @@ -1,11 +1,11 @@ import { List, Set } from 'immutable' import { createHash } from 'node:crypto' -import fs from 'node:fs' +import fs from 'node:fs/promises' import tf from '@tensorflow/tfjs' import '@tensorflow/tfjs-node' import type { Task, Path, Digest, TaskProvider } from '@epfml/discojs-core' -import { isTaskProvider, defaultTasks } from '@epfml/discojs-core' +import { isTaskProvider, defaultTasks, serialization } from '@epfml/discojs-core' // default tasks and added ones // register 'taskAndModel' event to get tasks @@ -36,30 +36,33 @@ export class TasksAndModels { let model: tf.LayersModel | undefined const modelPath = `./models/${discoTask.id}/` - if (fs.existsSync(modelPath)) { - // either a model has already been trained, or the pretrained - // model has already been downloaded - return await tf.loadLayersModel(`file://${modelPath}/model.json`) + try { + const content = await fs.readFile(`${modelPath}/model.json`) + return await serialization.model.decode(content) + } catch { + // unable to read file, continuing + } + + const modelURL = discoTask.trainingInformation.modelURL + if (modelURL !== undefined) { + model = await tf.loadLayersModel(modelURL) + } else if (isTaskProvider(task)) { + model = await task.getModel() } else { - const modelURL = discoTask.trainingInformation.modelURL - if (modelURL !== undefined) { - model = await tf.loadLayersModel(modelURL) - } else if (isTaskProvider(task)) { - model = await task.getModel() - } else { - throw new Error('model not provided in task definition') - } + throw new Error('model not provided in task definition') } - fs.mkdirSync(modelPath, { recursive: true }) - await model.save(`file://${modelPath}`, { includeOptimizer: true }) + await fs.mkdir(modelPath, { recursive: true }) + const encoded = await serialization.model.encode(model) + await fs.writeFile(`${modelPath}/model.json`, encoded) // Check digest if provided if (discoTask.digest !== undefined) { try { - this.checkDigest(discoTask.digest, modelPath) + await this.checkDigest(discoTask.digest, modelPath) } catch (e) { - TasksAndModels.removeModelFiles(modelPath) + console.warn('removing nodel files at', modelPath) + await fs.rm(modelPath, { recursive: true, force: true }) throw e } } @@ -67,9 +70,9 @@ export class TasksAndModels { return model } - private checkDigest (digest: Digest, modelPath: Path): void { + private async checkDigest (digest: Digest, modelPath: Path): Promise { const hash = createHash(digest.algorithm) - const modelConfigRaw = fs.readFileSync(`${modelPath}/model.json`) + const modelConfigRaw = await fs.readFile(`${modelPath}/model.json`) const modelConfig = JSON.parse(modelConfigRaw.toString()) const weightsFiles = modelConfig.weightsManifest[0].paths @@ -79,10 +82,10 @@ export class TasksAndModels { )) { throw new Error() } - weightsFiles.forEach((file: string) => { - const data = fs.readFileSync(`${modelPath}/${file}`) + await Promise.all(weightsFiles.map(async (file: string) => { + const data = await fs.readFile(`${modelPath}/${file}`) hash.update(data) - }) + })) const computedDigest = hash.digest('base64') if (computedDigest !== digest.value) { @@ -116,11 +119,4 @@ export class TasksAndModels { this.tasksAndModels = this.tasksAndModels.add([discoTask, tfModel]) this.emit('taskAndModel', discoTask, tfModel) } - - static removeModelFiles (path: Path): void { - console.warn('removing nodel files at', path) - fs.rm(path, { recursive: true, force: true }, (err) => { - if (err !== null) console.error(err) - }) - } } diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 095baf2e7..91c6db3ce 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -1,4 +1,4 @@ -import fs from 'fs/promises' +import fs from 'node:fs/promises' import path from 'node:path' import type { Server } from 'node:http' import { Range } from 'immutable'