diff --git a/discojs/discojs-core/src/aggregator/base.ts b/discojs/discojs-core/src/aggregator/base.ts index 5816fe84d..db5b2e93b 100644 --- a/discojs/discojs-core/src/aggregator/base.ts +++ b/discojs/discojs-core/src/aggregator/base.ts @@ -1,7 +1,6 @@ import { Map, Set } from 'immutable' -import type tf from '@tensorflow/tfjs' -import type { client, Task, AsyncInformant } from '..' +import type { client, Model, Task, AsyncInformant } from '..' import { EventEmitter } from '../utils/event_emitter' @@ -60,9 +59,9 @@ export abstract class Base { */ public readonly task: Task, /** - * The TF.js model whose weights are updated on aggregation. + * The Model whose weights are updated on aggregation. */ - protected _model?: tf.LayersModel, + protected _model?: Model, /** * The round cut-off for contributions. */ @@ -141,7 +140,7 @@ export abstract class Base { * Sets the aggregator's TF.js model. * @param model The new TF.js model */ - setModel (model: tf.LayersModel): void { + setModel (model: Model): void { this._model = model } @@ -267,7 +266,7 @@ export abstract class Base { /** * The aggregator's current model. */ - get model (): tf.LayersModel | undefined { + get model (): Model | undefined { return this._model } diff --git a/discojs/discojs-core/src/aggregator/mean.spec.ts b/discojs/discojs-core/src/aggregator/mean.spec.ts index 5cbbc0da5..ffd00afb5 100644 --- a/discojs/discojs-core/src/aggregator/mean.spec.ts +++ b/discojs/discojs-core/src/aggregator/mean.spec.ts @@ -1,8 +1,7 @@ import { assert, expect } from 'chai' import type { Map } from 'immutable' -import type tf from '@tensorflow/tfjs' -import type { client, Task } from '..' +import type { client, Model, Task } from '..' import { aggregator, defaultTasks } from '..' 
import { AggregationStep } from './base' @@ -16,7 +15,7 @@ const bufferCapacity = weights.length export class MockMeanAggregator extends aggregator.AggregatorBase { constructor ( task: Task, - model: tf.LayersModel, + model: Model, private readonly threshold: number, roundCutoff = 0 ) { diff --git a/discojs/discojs-core/src/aggregator/mean.ts b/discojs/discojs-core/src/aggregator/mean.ts index c964feab9..5d94baede 100644 --- a/discojs/discojs-core/src/aggregator/mean.ts +++ b/discojs/discojs-core/src/aggregator/mean.ts @@ -1,8 +1,7 @@ import type { Map } from 'immutable' -import type tf from '@tensorflow/tfjs' import { AggregationStep, Base as Aggregator } from './base' -import type { Task, WeightsContainer, client } from '..' +import type { Model, Task, WeightsContainer, client } from '..' import { aggregation } from '..' /** @@ -18,7 +17,7 @@ export class MeanAggregator extends Aggregator { constructor ( task: Task, - model?: tf.LayersModel, + model?: Model, roundCutoff = 0, threshold = 1 ) { @@ -69,7 +68,9 @@ export class MeanAggregator extends Aggregator { aggregate (): void { this.log(AggregationStep.AGGREGATE) const result = aggregation.avg(this.contributions.get(0)?.values() as Iterable) - this.model?.setWeights(result.weights) + if (this.model !== undefined) { + this.model.weights = result + } this.emit(result) } diff --git a/discojs/discojs-core/src/aggregator/secure.ts b/discojs/discojs-core/src/aggregator/secure.ts index a84111ea5..5b8558580 100644 --- a/discojs/discojs-core/src/aggregator/secure.ts +++ b/discojs/discojs-core/src/aggregator/secure.ts @@ -3,7 +3,7 @@ import { Map, List, Range } from 'immutable' import tf from '@tensorflow/tfjs' import { AggregationStep, Base as Aggregator } from './base' -import type { Task, WeightsContainer, client } from '..' +import type { Model, Task, WeightsContainer, client } from '..' import { aggregation } from '..' 
/** @@ -20,7 +20,7 @@ export class SecureAggregator extends Aggregator { constructor ( task: Task, - model?: tf.LayersModel + model?: Model ) { super(task, model, 0, 2) @@ -36,7 +36,9 @@ export class SecureAggregator extends Aggregator { } else if (this.communicationRound === 1) { // Average the received partial sums const result = aggregation.avg(this.contributions.get(1)?.values() as Iterable) - this.model?.setWeights(result.weights) + if (this.model !== undefined) { + this.model.weights = result + } this.emit(result) } else { throw new Error('communication round is out of bounds') diff --git a/discojs/discojs-core/src/client/base.ts b/discojs/discojs-core/src/client/base.ts index ee132eebb..89dd6854d 100644 --- a/discojs/discojs-core/src/client/base.ts +++ b/discojs/discojs-core/src/client/base.ts @@ -1,8 +1,7 @@ import axios from 'axios' import type { Set } from 'immutable' -import type tf from '@tensorflow/tfjs' -import type { Task, TrainingInformant, WeightsContainer } from '..' +import type { Model, Task, TrainingInformant, WeightsContainer } from '..' import { serialization } from '..' import type { NodeID } from './types' import type { EventConnection } from './event_connection' @@ -55,7 +54,7 @@ export abstract class Base { * Fetches the latest model available on the network's server, for the adequate task. * @returns The latest model */ - async getLatestModel (): Promise { + async getLatestModel (): Promise { const url = new URL('', this.url.href) if (!url.pathname.endsWith('/')) { url.pathname += '/' diff --git a/discojs/discojs-core/src/default_tasks/cifar10.ts b/discojs/discojs-core/src/default_tasks/cifar10.ts index bc280ea84..45bab46b6 100644 --- a/discojs/discojs-core/src/default_tasks/cifar10.ts +++ b/discojs/discojs-core/src/default_tasks/cifar10.ts @@ -1,7 +1,7 @@ import tf from '@tensorflow/tfjs' -import type { Task, TaskProvider } from '..' -import { data } from '..' +import type { Model, Task, TaskProvider } from '..' 
+import { data, models } from '..' export const cifar10: TaskProvider = { getTask (): Task { @@ -40,7 +40,7 @@ export const cifar10: TaskProvider = { } }, - async getModel (): Promise { + async getModel (): Promise { const mobilenet = await tf.loadLayersModel( 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json' ) @@ -61,6 +61,6 @@ export const cifar10: TaskProvider = { metrics: ['accuracy'] }) - return model + return new models.TFJS(model) } } diff --git a/discojs/discojs-core/src/default_tasks/geotags.ts b/discojs/discojs-core/src/default_tasks/geotags.ts index 53d754535..b7c39c7be 100644 --- a/discojs/discojs-core/src/default_tasks/geotags.ts +++ b/discojs/discojs-core/src/default_tasks/geotags.ts @@ -1,8 +1,8 @@ import { Range } from 'immutable' import tf from '@tensorflow/tfjs' -import type { Task, TaskProvider } from '..' -import { data } from '..' +import type { Model, Task, TaskProvider } from '..' +import { data, models } from '..' import { LabelTypeEnum } from '../task/label_type' export const geotags: TaskProvider = { @@ -44,7 +44,7 @@ export const geotags: TaskProvider = { } }, - async getModel (): Promise { + async getModel (): Promise { const pretrainedModel = await tf.loadLayersModel( 'https://storage.googleapis.com/deai-313515.appspot.com/models/geotags/model.json' ) @@ -68,6 +68,6 @@ export const geotags: TaskProvider = { metrics: ['accuracy'] }) - return model + return new models.TFJS(model) } } diff --git a/discojs/discojs-core/src/default_tasks/lus_covid.ts b/discojs/discojs-core/src/default_tasks/lus_covid.ts index 0b25c96fa..39e298a5d 100644 --- a/discojs/discojs-core/src/default_tasks/lus_covid.ts +++ b/discojs/discojs-core/src/default_tasks/lus_covid.ts @@ -1,7 +1,7 @@ import tf from '@tensorflow/tfjs' -import type { Task, TaskProvider } from '..' -import { data } from '..' +import type { Model, Task, TaskProvider } from '..' +import { data, models } from '..' 
export const lusCovid: TaskProvider = { getTask (): Task { @@ -40,7 +40,7 @@ export const lusCovid: TaskProvider = { } }, - async getModel (): Promise { + async getModel (): Promise { const imageHeight = 100 const imageWidth = 100 const imageChannels = 3 @@ -93,6 +93,6 @@ export const lusCovid: TaskProvider = { metrics: ['accuracy'] }) - return model + return new models.TFJS(model) } } diff --git a/discojs/discojs-core/src/default_tasks/mnist.ts b/discojs/discojs-core/src/default_tasks/mnist.ts index 199bcd5ae..d5707d76e 100644 --- a/discojs/discojs-core/src/default_tasks/mnist.ts +++ b/discojs/discojs-core/src/default_tasks/mnist.ts @@ -1,6 +1,7 @@ import tf from '@tensorflow/tfjs' -import type { Task, TaskProvider } from '..' +import type { Model, Task, TaskProvider } from '..' +import { models } from '..' export const mnist: TaskProvider = { getTask (): Task { @@ -39,7 +40,7 @@ export const mnist: TaskProvider = { } }, - async getModel (): Promise { + async getModel (): Promise { const model = tf.sequential() model.add( @@ -68,6 +69,6 @@ export const mnist: TaskProvider = { metrics: ['accuracy'] }) - return model + return new models.TFJS(model) } } diff --git a/discojs/discojs-core/src/default_tasks/simple_face.ts b/discojs/discojs-core/src/default_tasks/simple_face.ts index fd83314a5..f98ef6f10 100644 --- a/discojs/discojs-core/src/default_tasks/simple_face.ts +++ b/discojs/discojs-core/src/default_tasks/simple_face.ts @@ -1,6 +1,4 @@ -import type tf from '@tensorflow/tfjs' - -import type { Task, TaskProvider } from '..' +import type { Model, Task, TaskProvider } from '..' import { data } from '..' 
export const simpleFace: TaskProvider = { @@ -38,7 +36,7 @@ export const simpleFace: TaskProvider = { } }, - async getModel (): Promise { + async getModel (): Promise { throw new Error('Not implemented') } } diff --git a/discojs/discojs-core/src/default_tasks/skin_mnist.ts b/discojs/discojs-core/src/default_tasks/skin_mnist.ts index 45e24ea8f..21d2be8ee 100644 --- a/discojs/discojs-core/src/default_tasks/skin_mnist.ts +++ b/discojs/discojs-core/src/default_tasks/skin_mnist.ts @@ -1,7 +1,7 @@ import tf from '@tensorflow/tfjs' -import type { Task, TaskProvider } from '..' -import { data } from '..' +import type { Model, Task, TaskProvider } from '..' +import { data, models } from '..' export const skinMnist: TaskProvider = { getTask (): Task { @@ -47,7 +47,7 @@ export const skinMnist: TaskProvider = { } }, - async getModel (): Promise { + async getModel (): Promise { const numClasses = 7 const size = 28 @@ -98,6 +98,6 @@ export const skinMnist: TaskProvider = { metrics: ['accuracy'] }) - return model + return new models.TFJS(model) } } diff --git a/discojs/discojs-core/src/default_tasks/titanic.ts b/discojs/discojs-core/src/default_tasks/titanic.ts index c4edc3ba6..711e4431d 100644 --- a/discojs/discojs-core/src/default_tasks/titanic.ts +++ b/discojs/discojs-core/src/default_tasks/titanic.ts @@ -1,7 +1,7 @@ import tf from '@tensorflow/tfjs' -import type { Task, TaskProvider } from '..' -import { data } from '..' +import type { Model, Task, TaskProvider } from '..' +import { data, models } from '..' 
export const titanic: TaskProvider = { getTask (): Task { @@ -71,7 +71,7 @@ export const titanic: TaskProvider = { } }, - async getModel (): Promise { + async getModel (): Promise { const model = tf.sequential() model.add( @@ -92,6 +92,6 @@ export const titanic: TaskProvider = { metrics: ['accuracy'] }) - return model + return new models.TFJS(model) } } diff --git a/discojs/discojs-core/src/index.ts b/discojs/discojs-core/src/index.ts index fd2df191a..506111dd0 100644 --- a/discojs/discojs-core/src/index.ts +++ b/discojs/discojs-core/src/index.ts @@ -14,6 +14,9 @@ export { Memory, ModelType, type ModelInfo, type Path, type ModelSource, Empty a export { Disco, TrainingSchemes } from './training' export { Validator } from './validation' +export { Model } from './models' +export * as models from './models' + export * from './task' export * as defaultTasks from './default_tasks' diff --git a/discojs/discojs-core/src/memory/base.ts b/discojs/discojs-core/src/memory/base.ts index 59b4c85a8..7de82dad6 100644 --- a/discojs/discojs-core/src/memory/base.ts +++ b/discojs/discojs-core/src/memory/base.ts @@ -1,8 +1,7 @@ // only used browser-side // TODO: replace IO type -import type tf from '@tensorflow/tfjs' -import type { TaskID } from '..' +import type { Model, TaskID } from '..' import type { ModelType } from './model_type' /** @@ -49,7 +48,7 @@ export abstract class Memory { * @param source The model source * @returns The model */ - abstract getModel (source: ModelSource): Promise + abstract getModel (source: ModelSource): Promise /** * Removes the model identified by the given model source from memory. @@ -77,7 +76,7 @@ export abstract class Memory { * @param source The model source * @param model The new model */ - abstract updateWorkingModel (source: ModelSource, model: tf.LayersModel): Promise + abstract updateWorkingModel (source: ModelSource, model: Model): Promise /** * Creates a saved model copy from the working model identified by the given model source. 
@@ -94,7 +93,7 @@ export abstract class Memory { * @param model The new model * @returns The saved model's path */ - abstract saveModel (source: ModelSource, model: tf.LayersModel): Promise + abstract saveModel (source: ModelSource, model: Model): Promise /** * Moves the model identified by the model source to a file system. This is platform-dependent. diff --git a/discojs/discojs-core/src/memory/empty.ts b/discojs/discojs-core/src/memory/empty.ts index 6abeb2536..030222f31 100644 --- a/discojs/discojs-core/src/memory/empty.ts +++ b/discojs/discojs-core/src/memory/empty.ts @@ -1,4 +1,4 @@ -import type tf from '@tensorflow/tfjs' +import type { Model } from '..' import type { ModelInfo, Path } from './base' import { Memory } from './base' @@ -15,7 +15,7 @@ export class Empty extends Memory { return false } - async getModel (): Promise { + async getModel (): Promise { throw new Error('empty') } diff --git a/discojs/discojs-core/src/models/index.ts b/discojs/discojs-core/src/models/index.ts new file mode 100644 index 000000000..25b868724 --- /dev/null +++ b/discojs/discojs-core/src/models/index.ts @@ -0,0 +1,2 @@ +export { Model } from './model' +export { TFJS } from './tfjs' diff --git a/discojs/discojs-core/src/models/model.ts b/discojs/discojs-core/src/models/model.ts new file mode 100644 index 000000000..e33233304 --- /dev/null +++ b/discojs/discojs-core/src/models/model.ts @@ -0,0 +1,43 @@ +import type tf from '@tensorflow/tfjs' + +import type { WeightsContainer } from '..' 
+import type { EventEmitter } from '../utils/event_emitter'
+import type { Dataset } from '../dataset'
+
+// TODO still bound to tfjs
+export type EpochLogs = tf.Logs | undefined
+export type Prediction = tf.Tensor
+export type Sample = tf.Tensor
+
+// TODO remove as it's unused and kinda internal to tf
+export interface Events extends Record {
+  batchBegin: undefined
+  batchEnd: undefined
+}
+
+/** Trainable predictor */
+// TODO make it typesafe: same shape of data/input/weights
+export abstract class Model {
+  abstract get weights (): WeightsContainer
+  abstract set weights (ws: WeightsContainer)
+
+  /**
+   * Improve predictor
+   * @param trainingData dataset to optimize for
+   * @param validationData dataset to measure how well it is training
+   * @param epochs number of passes over the training dataset
+   * @param tracker watch the various steps
+   * @yields on every epoch, training can be stopped by `return`ing it
+   */
+  // TODO get rid of epoch & generator as total view is across the network
+  abstract train (
+    trainingData: Dataset,
+    validationData?: Dataset,
+    epochs?: number,
+    tracker?: EventEmitter
+  ): AsyncGenerator
+
+  /** Predict likely values */
+  // TODO extract in separated TrainedModel?
+  abstract predict (input: Sample): Promise
+}
diff --git a/discojs/discojs-core/src/models/tfjs.ts b/discojs/discojs-core/src/models/tfjs.ts
new file mode 100644
index 000000000..03a555f99
--- /dev/null
+++ b/discojs/discojs-core/src/models/tfjs.ts
@@ -0,0 +1,96 @@
+import tf from '@tensorflow/tfjs'
+
+import { Sink } from '../utils/event_emitter'
+import { WeightsContainer } from '..'
+
+import { Model } from '.'
+import type { EpochLogs, Prediction, Sample } from './model' +import type { Dataset } from '../dataset' + +export class TFJS extends Model { + constructor ( + private readonly model: tf.LayersModel + ) { + super() + + if (model.loss === null) { + throw new Error('TFJS models need to be compiled to be used') + } + } + + override get weights (): WeightsContainer { + return new WeightsContainer(this.model.weights.map((w) => w.read())) + } + + override set weights (ws: WeightsContainer) { + this.model.setWeights(ws.weights) + } + + override async * train ( + trainingData: Dataset, + validationData?: Dataset, + epochs = 1, + tracker = new Sink() + ): AsyncGenerator { + for (let i = 0; i < epochs; i++) { + let logs: tf.Logs | undefined + + await this.model.fitDataset(trainingData, { + epochs: 1, + validationData, + callbacks: { + onEpochEnd: (_, cur) => { logs = cur }, + onBatchBegin: () => { tracker.emit('batchBegin', undefined) }, + onBatchEnd: () => { tracker.emit('batchEnd', undefined) } + } + }) + + yield logs + } + } + + override async predict (input: Sample): Promise { + const ret = this.model.predict(input) + if (Array.isArray(ret)) { + throw new Error('prediction yield many Tensors but should have only returned one') + } + + return ret + } + + static async deserialize (raw: tf.io.ModelArtifacts): Promise { + return new this(await tf.loadLayersModel({ + load: async () => raw + })) + } + + async serialize (): Promise { + let resolveArtifacts: (_: tf.io.ModelArtifacts) => void + const ret = new Promise((resolve) => { resolveArtifacts = resolve }) + + await this.model.save({ + save: async (artifacts) => { + resolveArtifacts(artifacts) + return { + modelArtifactsInfo: { + dateSaved: new Date(), + modelTopologyType: 'JSON' + } + } + } + }, { + includeOptimizer: true // keep model compiled + }) + + return await ret + } + + /** + * extract wrapped model + * + * @deprecated use `Model` instead of relying on tf specifics + */ + extract (): tf.LayersModel { + return 
this.model + } +} diff --git a/discojs/discojs-core/src/serialization/model.spec.ts b/discojs/discojs-core/src/serialization/model.spec.ts index 728aac4ba..89261f090 100644 --- a/discojs/discojs-core/src/serialization/model.spec.ts +++ b/discojs/discojs-core/src/serialization/model.spec.ts @@ -1,26 +1,29 @@ import { assert } from 'chai' import tf from '@tensorflow/tfjs' -import { serialization } from '..' +import type { Model } from '..' +import { serialization, models } from '..' -async function getRawWeights (model: tf.LayersModel): Promise> { +async function getRawWeights (model: Model): Promise> { return Array.from( (await Promise.all( - model.getWeights().map(async (w) => await w.data<'float32'>())) + model.weights.weights.map(async (w) => await w.data<'float32'>())) ).entries() ) } describe('model', () => { it('can encode what it decodes', async () => { - const model = tf.sequential() - - model.add( - tf.layers.conv2d({ - inputShape: [32, 32, 3], - kernelSize: 3, - filters: 16, - activation: 'relu' + const model = new models.TFJS( + tf.sequential({ + layers: [ + tf.layers.conv2d({ + inputShape: [32, 32, 3], + kernelSize: 3, + filters: 16, + activation: 'relu' + }) + ] }) ) diff --git a/discojs/discojs-core/src/serialization/model.ts b/discojs/discojs-core/src/serialization/model.ts index d573709d4..fa4dcdd59 100644 --- a/discojs/discojs-core/src/serialization/model.ts +++ b/discojs/discojs-core/src/serialization/model.ts @@ -1,5 +1,8 @@ -import tf from '@tensorflow/tfjs' import msgpack from 'msgpack-lite' +import type tf from '@tensorflow/tfjs' + +import type { Model } from '..' +import { models } from '..' 
export type Encoded = Uint8Array @@ -7,31 +10,22 @@ export function isEncoded (raw: unknown): raw is Encoded { return raw instanceof Uint8Array } -export async function encode (model: tf.LayersModel): Promise { - const saved = await new Promise((resolve) => { - void model.save({ - save: async (artifacts) => { - resolve(artifacts) - return { - modelArtifactsInfo: { - dateSaved: new Date(), - modelTopologyType: 'JSON' - } - } - } - }, { includeOptimizer: true }) - }) +export async function encode (model: Model): Promise { + if (model instanceof models.TFJS) { + const serialized = await model.serialize() + return msgpack.encode(serialized) + } - return msgpack.encode(saved) + throw new Error('unknown model type') } -export async function decode (encoded: unknown): Promise { +export async function decode (encoded: unknown): Promise { if (!isEncoded(encoded)) { throw new Error('invalid encoding') } const raw = msgpack.decode(encoded) - return await tf.loadLayersModel({ - load: () => raw - }) + // TODO how to select model type? prepend with model id + // TODO totally unsafe casting + return await models.TFJS.deserialize(raw as tf.io.ModelArtifacts) } diff --git a/discojs/discojs-core/src/task/task_handler.ts b/discojs/discojs-core/src/task/task_handler.ts index 31c1749b1..e99021477 100644 --- a/discojs/discojs-core/src/task/task_handler.ts +++ b/discojs/discojs-core/src/task/task_handler.ts @@ -1,8 +1,9 @@ import axios from 'axios' import { Map } from 'immutable' -import type tf from '@tensorflow/tfjs' -import { serialization, WeightsContainer } from '..' +import type { Model } from '..' +import { serialization } from '..' 
+ import type { Task, TaskID } from './task' import { isTask } from './task' @@ -11,14 +12,14 @@ const TASK_ENDPOINT = 'tasks' export async function pushTask ( url: URL, task: Task, - model: tf.LayersModel + model: Model ): Promise { await axios.post( url.href + TASK_ENDPOINT, { task, model: await serialization.model.encode(model), - weights: await serialization.weights.encode(WeightsContainer.from(model)) + weights: await serialization.weights.encode(model.weights) } ) } diff --git a/discojs/discojs-core/src/task/task_provider.ts b/discojs/discojs-core/src/task/task_provider.ts index fb4fd2433..4aeb0725e 100644 --- a/discojs/discojs-core/src/task/task_provider.ts +++ b/discojs/discojs-core/src/task/task_provider.ts @@ -1,11 +1,9 @@ -import type tf from '@tensorflow/tfjs' - -import type { Task } from '..' +import type { Model, Task } from '..' export interface TaskProvider { getTask: () => Task // Create the corresponding model ready for training (compiled) - getModel: () => Promise + getModel: () => Promise } export function isTaskProvider (obj: any): obj is TaskProvider { diff --git a/discojs/discojs-core/src/training/trainer/distributed_trainer.ts b/discojs/discojs-core/src/training/trainer/distributed_trainer.ts index 525938808..b66521be4 100644 --- a/discojs/discojs-core/src/training/trainer/distributed_trainer.ts +++ b/discojs/discojs-core/src/training/trainer/distributed_trainer.ts @@ -1,7 +1,6 @@ import type tf from '@tensorflow/tfjs' -import type { Memory, Task, TrainingInformant, client as clients } from '../..' -import { WeightsContainer } from '../..' +import type { Model, Memory, Task, TrainingInformant, client as clients } from '../..' 
import type { Aggregator } from '../../aggregator' import { Trainer } from './trainer' @@ -19,7 +18,7 @@ export class DistributedTrainer extends Trainer { task: Task, trainingInformant: TrainingInformant, memory: Memory, - model: tf.LayersModel, + model: Model, private readonly client: clients.Client ) { super(task, trainingInformant, memory, model) @@ -29,29 +28,22 @@ export class DistributedTrainer extends Trainer { async onTrainBegin (logs?: tf.Logs): Promise { await super.onTrainBegin(logs) - - const weights = WeightsContainer.from(this.model) - - await this.client.onTrainBeginCommunication(weights, this.trainingInformant) + await this.client.onTrainBeginCommunication(this.model.weights, this.trainingInformant) } async onRoundBegin (accuracy: number): Promise { - const weights = WeightsContainer.from(this.model) - - await this.client.onRoundBeginCommunication(weights, this.roundTracker.round, this.trainingInformant) + await this.client.onRoundBeginCommunication(this.model.weights, this.roundTracker.round, this.trainingInformant) } /** * Callback called every time a round is over */ async onRoundEnd (accuracy: number): Promise { - const weights = WeightsContainer.from(this.model) - - await this.client.onRoundEndCommunication(weights, this.roundTracker.round, this.trainingInformant) + await this.client.onRoundEndCommunication(this.model.weights, this.roundTracker.round, this.trainingInformant) if (this.aggregator.model !== undefined) { // The aggregator's own aggregation is async. The trainer updates its model to match the aggregator's // after it has completed a round of training. 
- this.model.setWeights(this.aggregator.model.getWeights()) + this.model.weights = this.aggregator.model.weights } await this.memory.updateWorkingModel( diff --git a/discojs/discojs-core/src/training/trainer/trainer.ts b/discojs/discojs-core/src/training/trainer/trainer.ts index 5204146d7..eea1d866f 100644 --- a/discojs/discojs-core/src/training/trainer/trainer.ts +++ b/discojs/discojs-core/src/training/trainer/trainer.ts @@ -1,10 +1,11 @@ import type tf from '@tensorflow/tfjs' -import type { Memory, Task, TrainingInformant } from '../..' +import type { Memory, Model, Task, TrainingInformant } from '../..' import { RoundTracker } from './round_tracker' import type { TrainerLog } from '../../logging/trainer_logger' import { TrainerLogger } from '../../logging/trainer_logger' +import { EventEmitter } from '../../utils/event_emitter' /** Abstract class whose role is to train a model with a given dataset. This can be either done * locally (alone) or in a distributed way with collaborators. The Trainer works as follows: @@ -18,7 +19,7 @@ import { TrainerLogger } from '../../logging/trainer_logger' export abstract class Trainer { public readonly roundTracker: RoundTracker - private stopTrainingRequested = false + private training?: AsyncGenerator private readonly trainerLogger: TrainerLogger /** @@ -30,7 +31,7 @@ export abstract class Trainer { public readonly task: Task, public readonly trainingInformant: TrainingInformant, public readonly memory: Memory, - public readonly model: tf.LayersModel + public readonly model: Model ) { this.trainerLogger = new TrainerLogger() this.roundTracker = new RoundTracker(task.trainingInformation.roundDuration) @@ -52,7 +53,6 @@ export abstract class Trainer { } this.roundTracker.updateBatch() - this.stopTrainModelIfRequested() if (this.roundTracker.roundHasEnded()) { await this.onRoundEnd(logs.acc) @@ -100,7 +100,7 @@ export abstract class Trainer { * Request stop training to be used from the Disco instance or any class that is taking 
care of the trainer.
    */
   async stopTraining (): Promise {
-    this.stopTrainingRequested = true
+    await this.training?.return()
   }

   /**
@@ -111,23 +111,43 @@ export abstract class Trainer {
     dataset: tf.data.Dataset,
     valDataset: tf.data.Dataset
   ): Promise {
-    this.resetStopTrainerState()
+    if (this.training !== undefined) {
+      throw new Error('training already running, cancel it before launching a new one')
+    }
+
+    await this.onTrainBegin()

-    await this.model.fitDataset(
+    this.training = this.model.train(
       dataset,
-      {
-        epochs: this.task.trainingInformation.epochs,
-        validationData: valDataset,
-        callbacks: {
-          onEpochBegin: (e, l) => { this.onEpochBegin(e, l) },
-          onEpochEnd: (e, l) => { this.onEpochEnd(e, l) },
-          onBatchBegin: async (e, l) => { await this.onBatchBegin(e, l) },
-          onBatchEnd: async (e, l) => { await this.onBatchEnd(e, l) },
-          onTrainBegin: async (l) => { await this.onTrainBegin(l) },
-          onTrainEnd: async (l) => { await this.onTrainEnd(l) }
-        }
-      }
+      valDataset,
+      this.task.trainingInformation.epochs,
+      new EventEmitter({
+        // TODO implement
+        // epochBegin: () => this.onEpochBegin(),
+        // epochEnd: () => this.onEpochEnd(),
+        // batchBegin: async () => await this.onBatchBegin(),
+        // batchEnd: async () => await this.onBatchEnd(),
+      })
     )
+
+    try {
+      let epoch = 0
+      this.onEpochBegin(epoch)
+      for await (const logs of this.training) {
+        this.onEpochEnd(epoch, logs)
+
+        epoch += 1
+        if (epoch < this.task.trainingInformation.epochs) {
+          // `epoch` was just advanced, so it already is the index of the epoch about to start
+          this.onEpochBegin(epoch)
+        }
+      }
+    } finally {
+      // always drop the generator handle, even on failure, so that a new training can be launched
+      this.training = undefined
+    }
+
+    await this.onTrainEnd()
   }

   /**
@@ -137,24 +157,6 @@ export abstract class Trainer {
     return +(accuracy * 100).toFixed(decimalsToRound)
   }

-  /**
-   * reset stop training state
-   */
-  protected resetStopTrainerState (): void {
-    this.model.stopTraining = false
-    this.stopTrainingRequested = false
-  }
-
-  /**
-   * If stop training is requested, do so
-   */
-  protected stopTrainModelIfRequested (): void {
-    if (this.stopTrainingRequested) {
-      this.model.stopTraining = true
-
this.stopTrainingRequested = false - } - } - getTrainerLog (): TrainerLog { return this.trainerLogger.log } diff --git a/discojs/discojs-core/src/training/trainer/trainer_builder.ts b/discojs/discojs-core/src/training/trainer/trainer_builder.ts index a0fb16ee1..2d7719349 100644 --- a/discojs/discojs-core/src/training/trainer/trainer_builder.ts +++ b/discojs/discojs-core/src/training/trainer/trainer_builder.ts @@ -1,6 +1,4 @@ -import type tf from '@tensorflow/tfjs' - -import type { client as clients, Task, TrainingInformant, ModelInfo, Memory } from '../..' +import type { client as clients, Model, Task, TrainingInformant, ModelInfo, Memory } from '../..' import { ModelType } from '../..' import type { Aggregator } from '../../aggregator' @@ -49,7 +47,7 @@ export class TrainerBuilder { * If a model exists in memory, laod it, otherwise load model from server * @returns */ - private async getModel (client: clients.Client): Promise { + private async getModel (client: clients.Client): Promise { const modelID = this.task.trainingInformation?.modelID if (modelID === undefined) { throw new TypeError('model ID is undefined') diff --git a/discojs/discojs-core/src/validation/validator.ts b/discojs/discojs-core/src/validation/validator.ts index ce0457eae..fff18c513 100644 --- a/discojs/discojs-core/src/validation/validator.ts +++ b/discojs/discojs-core/src/validation/validator.ts @@ -1,7 +1,7 @@ import { List } from 'immutable' import tf from '@tensorflow/tfjs' -import type { data, Task, Logger, client as clients, Memory, ModelSource, Features } from '..' +import type { data, Model, Task, Logger, client as clients, Memory, ModelSource, Features } from '..' import { GraphInformant } from '..' 
export class Validator { @@ -21,11 +21,14 @@ export class Validator { } } - private getLabel (ys: tf.Tensor, isBinary: boolean): Float32Array | Int32Array | Uint8Array { - if (isBinary) { - return ys.greaterEqual(tf.scalar(0.5)).dataSync() - } else { - return ys.argMax(1).dataSync() + private async getLabel (ys: tf.Tensor): Promise { + switch (ys.shape[1]) { + case 1: + return await ys.greaterEqual(tf.scalar(0.5)).data() + case 2: + return await ys.argMax(1).data() + default: + throw new Error(`unable to reduce tensor of shape: ${ys.shape.toString()}`) } } @@ -36,7 +39,6 @@ export class Validator { } const model = await this.getModel() - const isBinary = model.loss === 'binaryCrossentropy' let features: Features[] = [] const groundTruth: number[] = [] @@ -48,8 +50,8 @@ export class Validator { .mapAsync(async e => { if (typeof e === 'object' && 'xs' in e && 'ys' in e) { const xs = e.xs as tf.Tensor - const ys = this.getLabel(e.ys as tf.Tensor, isBinary) - const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor, isBinary) + const ys = await this.getLabel(e.ys as tf.Tensor) + const pred = await this.getLabel(await model.predict(xs)) const currentFeatures = await xs.array() if (Array.isArray(currentFeatures)) { @@ -98,7 +100,6 @@ export class Validator { } const model = await this.getModel() - const isBinary = model.loss === 'binaryCrossentropy' let features: Features[] = [] // Get model prediction per batch and flatten the result @@ -114,7 +115,7 @@ export class Validator { throw new TypeError('Data format is incorrect') } - const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor, isBinary) + const pred = await this.getLabel(await model.predict(xs)) return Array.from(pred) }).toArray()).flat() @@ -123,7 +124,7 @@ export class Validator { .toArray() } - async getModel (): Promise { + async getModel (): Promise { if (this.source !== undefined && await this.memory.contains(this.source)) { return await this.memory.getModel(this.source) } 
diff --git a/discojs/discojs-core/src/weights/weights_container.ts b/discojs/discojs-core/src/weights/weights_container.ts index 6bc262844..2cf8af864 100644 --- a/discojs/discojs-core/src/weights/weights_container.ts +++ b/discojs/discojs-core/src/weights/weights_container.ts @@ -121,13 +121,4 @@ export class WeightsContainer { static of (...weights: TensorLike[]): WeightsContainer { return new this(weights) } - - /** - * Instantiates a new weights container from the given model's weights. - * @param model The TF.js model - * @returns The instantiated weights container - */ - static from (model: tf.LayersModel): WeightsContainer { - return new this(model.weights.map((w) => w.read())) - } } diff --git a/discojs/discojs-web/src/memory/memory.ts b/discojs/discojs-web/src/memory/memory.ts index 0ba1eaa37..654818ae7 100644 --- a/discojs/discojs-web/src/memory/memory.ts +++ b/discojs/discojs-web/src/memory/memory.ts @@ -10,11 +10,11 @@ import { Map } from 'immutable' import path from 'path' import * as tf from '@tensorflow/tfjs' -import type { Path, ModelInfo, ModelSource } from '@epfml/discojs-core' -import { Memory, ModelType } from '@epfml/discojs-core' +import type { Path, Model, ModelInfo, ModelSource } from '@epfml/discojs-core' +import { Memory, ModelType, models } from '@epfml/discojs-core' export class IndexedDB extends Memory { - pathFor (source: ModelSource): Path { + override pathFor (source: ModelSource): Path { if (typeof source === 'string') { return source } @@ -28,7 +28,7 @@ export class IndexedDB extends Memory { return `indexeddb://${path.join(source.type, source.taskID, source.name)}@${version}` } - infoFor (source: ModelSource): ModelInfo { + override infoFor (source: ModelSource): ModelInfo { if (typeof source !== 'string') { return source } @@ -50,8 +50,8 @@ export class IndexedDB extends Memory { return await this.getModelMetadata(source) !== undefined } - async getModel (source: ModelSource): Promise { - return await 
tf.loadLayersModel(this.pathFor(source)) + override async getModel (source: ModelSource): Promise { + return new models.TFJS(await tf.loadLayersModel(this.pathFor(source))) } async deleteModel (source: ModelSource): Promise { @@ -75,13 +75,19 @@ export class IndexedDB extends Memory { * @param source the destination * @param model the model */ - async updateWorkingModel (source: ModelSource, model: tf.LayersModel): Promise { + override async updateWorkingModel (source: ModelSource, model: Model): Promise { const src: ModelInfo = this.infoFor(source) if (src.type !== undefined && src.type !== ModelType.WORKING) { throw new Error('expected working model') } + + if (model instanceof models.TFJS) { + await model.extract().save(this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }), { includeOptimizer: true }) + } else { + throw new Error('unknown model type') + } + // Enforce version 0 to always keep a single working model at a time - await model.save(this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }), { includeOptimizer: true }) } /** @@ -101,13 +107,17 @@ export class IndexedDB extends Memory { return dst } - async saveModel (source: ModelSource, model: tf.LayersModel): Promise { + override async saveModel (source: ModelSource, model: Model): Promise { const src: ModelInfo = this.infoFor(source) if (src.type !== undefined && src.type !== ModelType.SAVED) { throw new Error('expected saved model') } const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED })) - await model.save(dst, { includeOptimizer: true }) + if (model instanceof models.TFJS) { + await model.extract().save(dst, { includeOptimizer: true }) + } else { + throw new Error('unknown model type') + } return dst } diff --git a/server/src/get_server.ts b/server/src/get_server.ts index 9cc465500..15f87b93a 100644 --- a/server/src/get_server.ts +++ b/server/src/get_server.ts @@ -1,14 +1,13 @@ import cors from 'cors' import express from 'express' import expressWS from 
'express-ws' -import type tf from '@tensorflow/tfjs' +import type * as http from 'http' -import type { Task, TaskProvider } from '@epfml/discojs-core' +import type { Model, Task, TaskProvider } from '@epfml/discojs-core' import { CONFIG } from './config' import { Router } from './router' import { TasksAndModels } from './tasks' -import type * as http from 'http' export class Disco { private readonly _app: express.Application @@ -29,7 +28,7 @@ export class Disco { } // If a model is not provided, its url must be provided in the task object - async addTask (task: Task | TaskProvider, model?: tf.LayersModel | URL): Promise { + async addTask (task: Task | TaskProvider, model?: Model | URL): Promise { await this.tasksAndModels.addTaskAndModel(task, model) } diff --git a/server/src/router/decentralized/server.ts b/server/src/router/decentralized/server.ts index 21e1f5dd1..86a4fcb9d 100644 --- a/server/src/router/decentralized/server.ts +++ b/server/src/router/decentralized/server.ts @@ -5,9 +5,8 @@ import type WebSocket from 'ws' import type { ParamsDictionary } from 'express-serve-static-core' import type { ParsedQs } from 'qs' import { Map, Set } from 'immutable' -import type tf from '@tensorflow/tfjs' -import type { Task, TaskID } from '@epfml/discojs-core' +import type { Model, Task, TaskID } from '@epfml/discojs-core' import { client } from '@epfml/discojs-core' import { Server } from '../server' @@ -44,12 +43,12 @@ export class Decentralized extends Server { ) } - protected initTask (task: Task, model: tf.LayersModel): void {} + protected initTask (task: Task, model: Model): void {} protected handle ( task: Task, ws: WebSocket, - model: tf.LayersModel, + model: Model, req: express.Request< ParamsDictionary, any, diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index 60ce74413..f29aa4c12 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -3,9 +3,9 @@ import type WebSocket from 
'ws' import { v4 as randomUUID } from 'uuid' import { List, Map } from 'immutable' import msgpack from 'msgpack-lite' -import type tf from '@tensorflow/tfjs' import type { + Model, Task, TaskID, WeightsContainer, @@ -108,7 +108,7 @@ export class Federated extends Server { void this.storeAggregationResult(aggregator) } - protected initTask (task: Task, model: tf.LayersModel): void { + protected initTask (task: Task, model: Model): void { const aggregator = new aggregators.MeanAggregator(task, model) this.aggregators = this.aggregators.set(task.id, aggregator) @@ -200,7 +200,7 @@ export class Federated extends Server { protected handle ( task: Task, ws: WebSocket, - model: tf.LayersModel, + model: Model, req: express.Request ): void { const taskAggregator = this.aggregators.get(task.id) diff --git a/server/src/router/server.ts b/server/src/router/server.ts index fc3e53d20..a7eed2e27 100644 --- a/server/src/router/server.ts +++ b/server/src/router/server.ts @@ -1,9 +1,8 @@ import express from 'express' import type expressWS from 'express-ws' import type WebSocket from 'ws' -import type tf from '@tensorflow/tfjs' -import type { Task } from '@epfml/discojs-core' +import type { Model, Task } from '@epfml/discojs-core' import type { TasksAndModels } from '../tasks' @@ -34,7 +33,7 @@ export abstract class Server { return this.ownRouter } - private onNewTask (task: Task, model: tf.LayersModel): void { + private onNewTask (task: Task, model: Model): void { this.tasks.push(task.id) this.initTask(task, model) @@ -66,12 +65,12 @@ export abstract class Server { protected abstract buildRoute (task: Task): string - protected abstract initTask (task: Task, model: tf.LayersModel): void + protected abstract initTask (task: Task, model: Model): void protected abstract handle ( task: Task, ws: WebSocket, - model: tf.LayersModel, + model: Model, req: express.Request, ): void } diff --git a/server/src/router/tasks.ts b/server/src/router/tasks.ts index 26d7193f3..e3274e36b 100644 --- 
a/server/src/router/tasks.ts +++ b/server/src/router/tasks.ts @@ -1,9 +1,8 @@ import type { Request, Response } from 'express' import express from 'express' import { Set } from 'immutable' -import type tf from '@tensorflow/tfjs' -import type { Task, TaskID } from '@epfml/discojs-core' +import type { Model, Task, TaskID } from '@epfml/discojs-core' import { serialization, isTask } from '@epfml/discojs-core' import type { Config } from '../config' @@ -12,7 +11,7 @@ import type { TasksAndModels } from '../tasks' export class Tasks { private readonly ownRouter: express.Router - private tasksAndModels = Set<[Task, tf.LayersModel]>() + private tasksAndModels = Set<[Task, Model]>() constructor ( private readonly config: Config, @@ -55,7 +54,7 @@ export class Tasks { return this.ownRouter } - onNewTask (task: Task, model: tf.LayersModel): void { + onNewTask (task: Task, model: Model): void { this.ownRouter.get(`/${task.id}/:file`, (req, res, next) => { this.getLatestModel(task.id, req, res).catch(next) }) diff --git a/server/src/tasks.ts b/server/src/tasks.ts index 93a9363e9..dc879ec0e 100644 --- a/server/src/tasks.ts +++ b/server/src/tasks.ts @@ -5,21 +5,21 @@ import tf from '@tensorflow/tfjs' import '@tensorflow/tfjs-node' import type { Task, Path, Digest, TaskProvider } from '@epfml/discojs-core' -import { isTaskProvider, defaultTasks, serialization } from '@epfml/discojs-core' +import { Model, isTaskProvider, defaultTasks, models, serialization } from '@epfml/discojs-core' // default tasks and added ones // register 'taskAndModel' event to get tasks // TODO save and load from disk export class TasksAndModels { - private listeners = List<(t: Task, m: tf.LayersModel) => void>() - tasksAndModels = Set<[Task, tf.LayersModel]>() + private listeners = List<(t: Task, m: Model) => void>() + tasksAndModels = Set<[Task, Model]>() - on (_: 'taskAndModel', callback: (t: Task, m: tf.LayersModel) => void): void { + on (_: 'taskAndModel', callback: (t: Task, m: Model) => void): void 
{ this.tasksAndModels.forEach(([t, m]) => { callback(t, m) }) this.listeners = this.listeners.push(callback) } - emit (_: 'taskAndModel', task: Task, model: tf.LayersModel): void { + emit (_: 'taskAndModel', task: Task, model: Model): void { this.listeners.forEach((listener) => { listener(task, model) }) } @@ -31,9 +31,9 @@ export class TasksAndModels { } // Returns already saved model in priority, then the model from the task definition - private async loadModelFromTask (task: Task | TaskProvider): Promise { + private async loadModelFromTask (task: Task | TaskProvider): Promise { const discoTask = isTaskProvider(task) ? task.getTask() : task - let model: tf.LayersModel | undefined + let model: Model | undefined const modelPath = `./models/${discoTask.id}/` try { @@ -45,7 +45,7 @@ export class TasksAndModels { const modelURL = discoTask.trainingInformation.modelURL if (modelURL !== undefined) { - model = await tf.loadLayersModel(modelURL) + model = new models.TFJS(await tf.loadLayersModel(modelURL)) } else if (isTaskProvider(task)) { model = await task.getModel() } else { @@ -96,22 +96,21 @@ export class TasksAndModels { } } - async addTaskAndModel (task: Task | TaskProvider, model?: tf.LayersModel | URL): Promise { - let tfModel: tf.LayersModel + async addTaskAndModel (task: Task | TaskProvider, model?: Model | URL): Promise { let discoTask: Task - if (isTaskProvider(task)) { discoTask = task.getTask() } else { discoTask = task } + let tfModel: Model if (model === undefined) { tfModel = await this.loadModelFromTask(task) - } else if (model instanceof tf.LayersModel) { + } else if (model instanceof Model) { tfModel = model } else if (model instanceof URL) { - tfModel = await tf.loadLayersModel(model.href) + tfModel = new models.TFJS(await tf.loadLayersModel(model.href)) } else { throw new Error('invalid model') } diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 91c6db3ce..e81360690 100644 --- 
a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -4,8 +4,9 @@ import type { Server } from 'node:http' import { Range } from 'immutable' import { assert } from 'chai' +import type { WeightsContainer } from '@epfml/discojs-core' import { - WeightsContainer, Disco, TrainingSchemes, client as clients, + Disco, TrainingSchemes, client as clients, aggregator as aggregators, informant, defaultTasks } from '@epfml/discojs-core' import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node' @@ -32,7 +33,7 @@ describe('end-to-end federated', function () { const cifar10Task = defaultTasks.cifar10.getTask() - const data = await new NodeImageLoader(cifar10Task).loadAll(files, { labels }) + const data = await new NodeImageLoader(cifar10Task).loadAll(files, { labels, shuffle: false }) const aggregator = new aggregators.MeanAggregator(cifar10Task) const client = await getClient(clients.federated.FederatedClient, server, cifar10Task, aggregator) @@ -44,7 +45,7 @@ describe('end-to-end federated', function () { if (aggregator.model === undefined) { throw new Error('model was not set') } - return WeightsContainer.from(aggregator.model) + return aggregator.model.weights } async function titanicUser (): Promise { @@ -80,7 +81,7 @@ describe('end-to-end federated', function () { trainingInformant.validationAccuracy() > 0.6, `expected validation accuracy greater than 0.6 but got ${trainingInformant.validationAccuracy()}` ) - return WeightsContainer.from(aggregator.model) + return aggregator.model.weights } it('two cifar10 users reach consensus', async () => {