Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add foundations for LLM #643

Merged
merged 8 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
)

let supportedTasks: Map<string, Task> = Map()
supportedTasks = supportedTasks.set(defaultTasks.simpleFace.getTask().taskID, defaultTasks.simpleFace.getTask())
supportedTasks = supportedTasks.set(defaultTasks.titanic.getTask().taskID, defaultTasks.titanic.getTask())
supportedTasks = supportedTasks.set(defaultTasks.cifar10.getTask().taskID, defaultTasks.cifar10.getTask())
supportedTasks = supportedTasks.set(defaultTasks.simpleFace.getTask().id, defaultTasks.simpleFace.getTask())
supportedTasks = supportedTasks.set(defaultTasks.titanic.getTask().id, defaultTasks.titanic.getTask())
supportedTasks = supportedTasks.set(defaultTasks.cifar10.getTask().id, defaultTasks.cifar10.getTask())

const task = supportedTasks.get(unsafeArgs.task)
if (task === undefined) {
Expand Down
20 changes: 12 additions & 8 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,45 @@
import { Range } from 'immutable'
import type { Server } from 'node:http'

import type { TrainerLog, data, Task } from '@epfml/discojs-core'
import { Disco, TrainingSchemes } from '@epfml/discojs-core'
import { Disco, TrainingSchemes, aggregator as aggregators, client as clients } from '@epfml/discojs-core'
import { getClient, startServer } from '@epfml/disco-server'

import { startServer, saveLog } from './utils'
import { saveLog } from './utils'
import { getTaskData } from './data'
import { args } from './args'

const NUMBER_OF_USERS = args.numberOfUsers
const TASK = args.task

const infoText = `\nStarted federated training of ${TASK.taskID}`
const infoText = `\nStarted federated training of ${TASK.id}`
console.log(infoText)

console.log({ args })

async function runUser (task: Task, url: URL, data: data.DataSplit): Promise<TrainerLog> {
async function runUser (task: Task, server: Server, data: data.DataSplit): Promise<TrainerLog> {
const client = await getClient(clients.federated.FederatedClient, server, task, new aggregators.MeanAggregator(TASK))
tharvik marked this conversation as resolved.
Show resolved Hide resolved

// force the federated scheme
const scheme = TrainingSchemes.FEDERATED
const disco = new Disco(task, { scheme, url })
const disco = new Disco(task, { scheme, client })

await disco.fit(data)
await disco.close()
return await disco.logs()
}

async function main (): Promise<void> {
const [server, serverUrl] = await startServer()
const server = await startServer()

const data = await getTaskData(TASK)

const logs = await Promise.all(
Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, serverUrl, data)).toArray()
Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, server, data)).toArray()
)

if (args.save) {
const fileName = `${TASK.taskID}_${NUMBER_OF_USERS}users.csv`
const fileName = `${TASK.id}_${NUMBER_OF_USERS}users.csv`
saveLog(logs, fileName)
}
console.log('Shutting down the server...')
Expand Down
4 changes: 2 additions & 2 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async function titanicData (titanic: Task): Promise<data.DataSplit> {
}

export async function getTaskData (task: Task): Promise<data.DataSplit> {
switch (task.taskID) {
switch (task.id) {
case 'simple_face':
return await simplefaceData(task)
case 'titanic':
Expand All @@ -71,6 +71,6 @@ export async function getTaskData (task: Task): Promise<data.DataSplit> {
case 'YOUR CUSTOM TASK HERE':
throw new Error('YOUR CUSTOM FUNCTION HERE')
default:
throw new Error(`Data loader for ${task.taskID} not implemented.`)
throw new Error(`Data loader for ${task.id} not implemented.`)
}
}
32 changes: 0 additions & 32 deletions cli/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,6 @@
import type http from 'node:http'
import fs from 'node:fs'

import type { TrainerLog } from '@epfml/discojs-core'
import { Disco } from '@epfml/disco-server'

/**
 * Starts a Disco server on port 8000 with the default tasks registered and
 * waits until it is listening.
 *
 * @returns the listening HTTP server and the base URL built from its bound address
 * @throws {Error} if the server emits an error before it starts listening, or
 * if its bound address cannot be determined
 */
export async function startServer (): Promise<[http.Server, URL]> {
  const disco = new Disco()
  await disco.addDefaultTasks()

  const server = disco.serve(8000)
  await new Promise((resolve, reject) => {
    server.once('listening', resolve)
    server.once('error', reject)
    // keep reporting errors raised after startup
    server.on('error', console.error)
  })

  const rawAddr = server.address()
  if (rawAddr === null) {
    throw new Error('unable to get server address')
  }

  let addr: string
  if (typeof rawAddr === 'string') {
    // pipe or Unix domain socket: the address is already a plain string
    addr = rawAddr
  } else {
    // Node reports the family as 'IPv4' / 'IPv6' (a bare number in some
    // Node 18 releases); '4' is still accepted for backward compatibility.
    // Only IPv6 literals may be wrapped in brackets inside a URL host.
    const family = String(rawAddr.family)
    const isIPv4 = family === 'IPv4' || family === '4'
    addr = isIPv4
      ? `${rawAddr.address}:${rawAddr.port}`
      : `[${rawAddr.address}]:${rawAddr.port}`
  }

  return [server, new URL('', `http://${addr}`)]
}

export function saveLog (logs: TrainerLog[], fileName: string): void {
const filePath = `./${fileName}`
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"dependencies": {
"@tensorflow/tfjs": "4",
"@types/msgpack-lite": "0.1",
"axios": "0.27",
"axios": "1",
"gpt3-tokenizer": "1",
"immutable": "4",
"isomorphic-wrtc": "1",
Expand Down
36 changes: 9 additions & 27 deletions discojs/discojs-core/src/aggregator/base.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,15 @@
import { List, Map, Set } from 'immutable'
import type tf from '@tensorflow/tfjs'
import { Map, Set } from 'immutable'

import type { client, Task, AsyncInformant } from '..'
import type { client, Model, Task, AsyncInformant } from '..'

import { EventEmitter } from '../utils/event_emitter'

/**
 * Identifies a step of an aggregator's workflow.
 * Values are passed to the aggregator's logging when the matching step occurs.
 */
export enum AggregationStep {
  ADD = 0,
  UPDATE = 1,
  AGGREGATE = 2
}

/**
 * Small emitter for the single 'aggregation' event.
 *
 * Listeners are kept in registration order, each flagged as persistent
 * (`on`) or one-shot (`once`).
 */
class AggregationEventEmitter<T> {
  private listeners = List<[once: boolean, act: (_: T) => void]>()

  /** Registers `act` to run on every emitted aggregation. */
  on (_: 'aggregation', act: (_: T) => void): void {
    this.register(false, act)
  }

  /** Registers `act` to run on the next emitted aggregation only. */
  once (_: 'aggregation', act: (_: T) => void): void {
    this.register(true, act)
  }

  /**
   * Calls every listener registered so far with `aggregated`.
   * One-shot listeners are dropped before any listener runs.
   */
  emit (_: 'aggregation', aggregated: T): void {
    const snapshot = this.listeners
    this.listeners = snapshot.filter(([oneShot]) => !oneShot)
    for (const [, handler] of snapshot) {
      handler(aggregated)
    }
  }

  /** Appends a listener entry to the registration list. */
  private register (oneShot: boolean, handler: (_: T) => void): void {
    this.listeners = this.listeners.push([oneShot, handler])
  }
}

/**
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
* a result based off their aggregation, whenever some defined condition is met.
Expand All @@ -48,7 +31,7 @@ export abstract class Base<T> {
* Triggers the resolve of the result promise and the preparation for the
* next aggregation round.
*/
private readonly eventEmitter: AggregationEventEmitter<T>
private readonly eventEmitter = new EventEmitter<{ 'aggregation': T }>()

protected informant?: AsyncInformant<T>
/**
Expand Down Expand Up @@ -76,9 +59,9 @@ export abstract class Base<T> {
*/
public readonly task: Task,
/**
* The TF.js model whose weights are updated on aggregation.
* The Model whose weights are updated on aggregation.
*/
protected _model?: tf.LayersModel,
protected _model?: Model,
/**
* The round cut-off for contributions.
*/
Expand All @@ -88,7 +71,6 @@ export abstract class Base<T> {
*/
public readonly communicationRounds = 1
) {
this.eventEmitter = new AggregationEventEmitter()
this.contributions = Map()
this._nodes = Set()

Expand Down Expand Up @@ -158,7 +140,7 @@ export abstract class Base<T> {
* Sets the aggregator's model.
* @param model The new model
*/
setModel (model: tf.LayersModel): void {
setModel (model: Model): void {
this._model = model
}

Expand Down Expand Up @@ -284,7 +266,7 @@ export abstract class Base<T> {
/**
* The aggregator's current model.
*/
get model (): tf.LayersModel | undefined {
get model (): Model | undefined {
return this._model
}

Expand Down
5 changes: 2 additions & 3 deletions discojs/discojs-core/src/aggregator/mean.spec.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { assert, expect } from 'chai'
import type { Map } from 'immutable'
import type tf from '@tensorflow/tfjs'

import type { client, Task } from '..'
import type { client, Model, Task } from '..'
import { aggregator, defaultTasks } from '..'
import { AggregationStep } from './base'

Expand All @@ -16,7 +15,7 @@ const bufferCapacity = weights.length
export class MockMeanAggregator extends aggregator.AggregatorBase<number> {
constructor (
task: Task,
model: tf.LayersModel,
model: Model,
private readonly threshold: number,
roundCutoff = 0
) {
Expand Down
9 changes: 5 additions & 4 deletions discojs/discojs-core/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import type { Map } from 'immutable'
import type tf from '@tensorflow/tfjs'

import { AggregationStep, Base as Aggregator } from './base'
import type { Task, WeightsContainer, client } from '..'
import type { Model, Task, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
Expand All @@ -18,7 +17,7 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {

constructor (
task: Task,
model?: tf.LayersModel,
model?: Model,
roundCutoff = 0,
threshold = 1
) {
Expand Down Expand Up @@ -69,7 +68,9 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
aggregate (): void {
this.log(AggregationStep.AGGREGATE)
const result = aggregation.avg(this.contributions.get(0)?.values() as Iterable<WeightsContainer>)
this.model?.setWeights(result.weights)
if (this.model !== undefined) {
this.model.weights = result
}
this.emit(result)
}

Expand Down
8 changes: 5 additions & 3 deletions discojs/discojs-core/src/aggregator/secure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Map, List, Range } from 'immutable'
import tf from '@tensorflow/tfjs'

import { AggregationStep, Base as Aggregator } from './base'
import type { Task, WeightsContainer, client } from '..'
import type { Model, Task, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
Expand All @@ -20,7 +20,7 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {

constructor (
task: Task,
model?: tf.LayersModel
model?: Model
) {
super(task, model, 0, 2)

Expand All @@ -36,7 +36,9 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
} else if (this.communicationRound === 1) {
// Average the received partial sums
const result = aggregation.avg(this.contributions.get(1)?.values() as Iterable<WeightsContainer>)
this.model?.setWeights(result.weights)
if (this.model !== undefined) {
this.model.weights = result
}
this.emit(result)
} else {
throw new Error('communication round is out of bounds')
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/async_informant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export class AsyncInformant<T> {

// Debug
public printAllInfos (): void {
console.debug('task:', this.aggregator.task.taskID)
console.debug('task:', this.aggregator.task.id)
console.debug('round:', this.round)
console.debug('participants:', this.currentNumberOfParticipants)
console.debug('total:', this.totalNumberOfParticipants)
Expand Down
12 changes: 5 additions & 7 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import axios from 'axios'
import type { Set } from 'immutable'
import type tf from '@tensorflow/tfjs'

import type { Task, TrainingInformant, WeightsContainer } from '..'
import type { Model, Task, TrainingInformant, WeightsContainer } from '..'
import { serialization } from '..'
import type { NodeID } from './types'
import type { EventConnection } from './event_connection'
Expand Down Expand Up @@ -55,16 +54,15 @@ export abstract class Base {
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel (): Promise<tf.LayersModel> {
async getLatestModel (): Promise<Model> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
}
url.pathname += `tasks/${this.task.taskID}/model.json`
url.pathname += `tasks/${this.task.id}/model.json`

const response = await axios.get(url.href)

return await serialization.model.decode(response.data)
const response = await axios.get<ArrayBuffer>(url.href, { responseType: 'arraybuffer' })
return await serialization.model.decode(new Uint8Array(response.data))
}

/**
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/client/decentralized/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ export class Base extends Client {
default:
throw new Error(`unknown protocol: ${this.url.protocol}`)
}
serverURL.pathname += `deai/${this.task.taskID}`
serverURL.pathname += `deai/${this.task.id}`

this._server = await this.connectServer(serverURL)

Expand Down
Loading
Loading