Skip to content

Commit

Permalink
discojs-core/train/models: initial
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Mar 1, 2024
1 parent 6951421 commit 802b63b
Show file tree
Hide file tree
Showing 35 changed files with 342 additions and 216 deletions.
11 changes: 5 additions & 6 deletions discojs/discojs-core/src/aggregator/base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Map, Set } from 'immutable'
import type tf from '@tensorflow/tfjs'

import type { client, Task, AsyncInformant } from '..'
import type { client, Model, Task, AsyncInformant } from '..'

import { EventEmitter } from '../utils/event_emitter'

Expand Down Expand Up @@ -60,9 +59,9 @@ export abstract class Base<T> {
*/
public readonly task: Task,
/**
* The TF.js model whose weights are updated on aggregation.
* The Model whose weights are updated on aggregation.
*/
protected _model?: tf.LayersModel,
protected _model?: Model,
/**
* The round cut-off for contributions.
*/
Expand Down Expand Up @@ -141,7 +140,7 @@ export abstract class Base<T> {
* Sets the aggregator's TF.js model.
* @param model The new TF.js model
*/
setModel (model: tf.LayersModel): void {
setModel (model: Model): void {
this._model = model
}

Expand Down Expand Up @@ -267,7 +266,7 @@ export abstract class Base<T> {
/**
* The aggregator's current model.
*/
get model (): tf.LayersModel | undefined {
get model (): Model | undefined {
return this._model
}

Expand Down
5 changes: 2 additions & 3 deletions discojs/discojs-core/src/aggregator/mean.spec.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { assert, expect } from 'chai'
import type { Map } from 'immutable'
import type tf from '@tensorflow/tfjs'

import type { client, Task } from '..'
import type { client, Model, Task } from '..'
import { aggregator, defaultTasks } from '..'
import { AggregationStep } from './base'

Expand All @@ -16,7 +15,7 @@ const bufferCapacity = weights.length
export class MockMeanAggregator extends aggregator.AggregatorBase<number> {
constructor (
task: Task,
model: tf.LayersModel,
model: Model,
private readonly threshold: number,
roundCutoff = 0
) {
Expand Down
9 changes: 5 additions & 4 deletions discojs/discojs-core/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import type { Map } from 'immutable'
import type tf from '@tensorflow/tfjs'

import { AggregationStep, Base as Aggregator } from './base'
import type { Task, WeightsContainer, client } from '..'
import type { Model, Task, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
Expand All @@ -18,7 +17,7 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {

constructor (
task: Task,
model?: tf.LayersModel,
model?: Model,
roundCutoff = 0,
threshold = 1
) {
Expand Down Expand Up @@ -69,7 +68,9 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
aggregate (): void {
this.log(AggregationStep.AGGREGATE)
const result = aggregation.avg(this.contributions.get(0)?.values() as Iterable<WeightsContainer>)
this.model?.setWeights(result.weights)
if (this.model !== undefined) {
this.model.weights = result
}
this.emit(result)
}

Expand Down
8 changes: 5 additions & 3 deletions discojs/discojs-core/src/aggregator/secure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Map, List, Range } from 'immutable'
import tf from '@tensorflow/tfjs'

import { AggregationStep, Base as Aggregator } from './base'
import type { Task, WeightsContainer, client } from '..'
import type { Model, Task, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
Expand All @@ -20,7 +20,7 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {

constructor (
task: Task,
model?: tf.LayersModel
model?: Model
) {
super(task, model, 0, 2)

Expand All @@ -36,7 +36,9 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
} else if (this.communicationRound === 1) {
// Average the received partial sums
const result = aggregation.avg(this.contributions.get(1)?.values() as Iterable<WeightsContainer>)
this.model?.setWeights(result.weights)
if (this.model !== undefined) {
this.model.weights = result
}
this.emit(result)
} else {
throw new Error('communication round is out of bounds')
Expand Down
5 changes: 2 additions & 3 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import axios from 'axios'
import type { Set } from 'immutable'
import type tf from '@tensorflow/tfjs'

import type { Task, TrainingInformant, WeightsContainer } from '..'
import type { Model, Task, TrainingInformant, WeightsContainer } from '..'
import { serialization } from '..'
import type { NodeID } from './types'
import type { EventConnection } from './event_connection'
Expand Down Expand Up @@ -55,7 +54,7 @@ export abstract class Base {
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel (): Promise<tf.LayersModel> {
async getLatestModel (): Promise<Model> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
Expand Down
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/default_tasks/cifar10.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'
import type { Model, Task, TaskProvider } from '..'
import { data, models } from '..'

export const cifar10: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -40,7 +40,7 @@ export const cifar10: TaskProvider = {
}
},

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
const mobilenet = await tf.loadLayersModel(
'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
)
Expand All @@ -61,6 +61,6 @@ export const cifar10: TaskProvider = {
metrics: ['accuracy']
})

return model
return new models.TFJS(model)
}
}
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/default_tasks/geotags.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { Range } from 'immutable'
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'
import type { Model, Task, TaskProvider } from '..'
import { data, models } from '..'
import { LabelTypeEnum } from '../task/label_type'

export const geotags: TaskProvider = {
Expand Down Expand Up @@ -44,7 +44,7 @@ export const geotags: TaskProvider = {
}
},

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
const pretrainedModel = await tf.loadLayersModel(
'https://storage.googleapis.com/deai-313515.appspot.com/models/geotags/model.json'
)
Expand All @@ -68,6 +68,6 @@ export const geotags: TaskProvider = {
metrics: ['accuracy']
})

return model
return new models.TFJS(model)
}
}
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/default_tasks/lus_covid.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'
import type { Model, Task, TaskProvider } from '..'
import { data, models } from '..'

export const lusCovid: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -40,7 +40,7 @@ export const lusCovid: TaskProvider = {
}
},

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
const imageHeight = 100
const imageWidth = 100
const imageChannels = 3
Expand Down Expand Up @@ -93,6 +93,6 @@ export const lusCovid: TaskProvider = {
metrics: ['accuracy']
})

return model
return new models.TFJS(model)
}
}
7 changes: 4 additions & 3 deletions discojs/discojs-core/src/default_tasks/mnist.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import type { Model, Task, TaskProvider } from '..'
import { models } from '..'

export const mnist: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -39,7 +40,7 @@ export const mnist: TaskProvider = {
}
},

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
const model = tf.sequential()

model.add(
Expand Down Expand Up @@ -68,6 +69,6 @@ export const mnist: TaskProvider = {
metrics: ['accuracy']
})

return model
return new models.TFJS(model)
}
}
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/default_tasks/simple_face.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'
import type { Model, Task, TaskProvider } from '..'
import { data, models } from '..'

export const simpleFace: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -37,7 +37,7 @@ export const simpleFace: TaskProvider = {
}
},

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
const model = await tf.loadLayersModel(
'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json'
)
Expand All @@ -48,6 +48,6 @@ export const simpleFace: TaskProvider = {
metrics: ['accuracy']
})

return model
return new models.TFJS(model)
}
}
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/default_tasks/skin_mnist.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'
import type { Model, Task, TaskProvider } from '..'
import { data, models } from '..'

export const skinMnist: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -47,7 +47,7 @@ export const skinMnist: TaskProvider = {
}
},

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
const numClasses = 7
const size = 28

Expand Down Expand Up @@ -98,6 +98,6 @@ export const skinMnist: TaskProvider = {
metrics: ['accuracy']
})

return model
return new models.TFJS(model)
}
}
8 changes: 4 additions & 4 deletions discojs/discojs-core/src/default_tasks/titanic.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'
import type { Model, Task, TaskProvider } from '..'
import { data, models } from '..'

export const titanic: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -71,7 +71,7 @@ export const titanic: TaskProvider = {
}
},

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
const model = tf.sequential()

model.add(
Expand All @@ -92,6 +92,6 @@ export const titanic: TaskProvider = {
metrics: ['accuracy']
})

return model
return new models.TFJS(model)
}
}
3 changes: 3 additions & 0 deletions discojs/discojs-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ export { Memory, ModelType, type ModelInfo, type Path, type ModelSource, Empty a
export { Disco, TrainingSchemes } from './training'
export { Validator } from './validation'

export { Model } from './models'
export * as models from './models'

export * from './task'
export * as defaultTasks from './default_tasks'

Expand Down
9 changes: 4 additions & 5 deletions discojs/discojs-core/src/memory/base.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// only used browser-side
// TODO: replace IO type
import type tf from '@tensorflow/tfjs'

import type { TaskID } from '..'
import type { Model, TaskID } from '..'
import type { ModelType } from './model_type'

/**
Expand Down Expand Up @@ -49,7 +48,7 @@ export abstract class Memory {
* @param source The model source
* @returns The model
*/
abstract getModel (source: ModelSource): Promise<tf.LayersModel>
abstract getModel (source: ModelSource): Promise<Model>

/**
* Removes the model identified by the given model source from memory.
Expand Down Expand Up @@ -77,7 +76,7 @@ export abstract class Memory {
* @param source The model source
* @param model The new model
*/
abstract updateWorkingModel (source: ModelSource, model: tf.LayersModel): Promise<void>
abstract updateWorkingModel (source: ModelSource, model: Model): Promise<void>

/**
* Creates a saved model copy from the working model identified by the given model source.
Expand All @@ -94,7 +93,7 @@ export abstract class Memory {
* @param model The new model
* @returns The saved model's path
*/
abstract saveModel (source: ModelSource, model: tf.LayersModel): Promise<Path | undefined>
abstract saveModel (source: ModelSource, model: Model): Promise<Path | undefined>

/**
* Moves the model identified by the model source to a file system. This is platform-dependent.
Expand Down
4 changes: 2 additions & 2 deletions discojs/discojs-core/src/memory/empty.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type tf from '@tensorflow/tfjs'
import type { Model } from '..'

import type { ModelInfo, Path } from './base'
import { Memory } from './base'
Expand All @@ -15,7 +15,7 @@ export class Empty extends Memory {
return false
}

async getModel (): Promise<tf.LayersModel> {
async getModel (): Promise<Model> {
throw new Error('empty')
}

Expand Down
2 changes: 2 additions & 0 deletions discojs/discojs-core/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export { Model } from './model'
export { TFJS } from './tfjs'
Loading

0 comments on commit 802b63b

Please sign in to comment.