diff --git a/.github/workflows/lint-test-build.yml b/.github/workflows/lint-test-build.yml index c604358fa..d0f593b72 100644 --- a/.github/workflows/lint-test-build.yml +++ b/.github/workflows/lint-test-build.yml @@ -277,9 +277,9 @@ jobs: browser: chromium start: npm start install: false - wait-on: http://localhost:8080/ + wait-on: http://localhost:8081/ working-directory: ./web-client - config: baseUrl=http://localhost:8080/#/ + config: baseUrl=http://localhost:8081/#/ test-cli: needs: diff --git a/discojs/discojs-core/src/client/event_connection.ts b/discojs/discojs-core/src/client/event_connection.ts index 56d622453..d7b5f3179 100644 --- a/discojs/discojs-core/src/client/event_connection.ts +++ b/discojs/discojs-core/src/client/event_connection.ts @@ -126,7 +126,7 @@ export class WebSocketServer implements EventConnection { return await new Promise((resolve, reject) => { ws.onerror = (err: isomorphic.ErrorEvent) => - reject(new Error(`connecting server: ${err.message}`)) // eslint-disable-line @typescript-eslint/restrict-template-expressions + reject(new Error(`Server unreachable: ${err.message}`)) // eslint-disable-line @typescript-eslint/restrict-template-expressions ws.onopen = () => resolve(server) }) } diff --git a/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts b/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts index 34e4ddd67..c9509528e 100644 --- a/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts +++ b/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts @@ -18,10 +18,16 @@ interface TabularEntry extends tf.TensorContainerObject { const sanitize: PreprocessingFunction = { type: TabularPreprocessing.Sanitize, apply: (entry: tf.TensorContainer, task: Task): tf.TensorContainer => { - const { xs, ys } = entry as TabularEntry - return { - xs: xs.map(i => i === undefined ? 0 : i), - ys: ys + // if preprocessing a dataset without labels, then the entry is an array of numbers + if (Array.isArray(entry)) { + return entry.map(i => i === undefined ? 0 : i) + // otherwise it is an object with feature and labels + } else { + const { xs, ys } = entry as TabularEntry + return { + xs: xs.map(i => i === undefined ? 0 : i), + ys: ys + } } } } diff --git a/discojs/discojs-core/src/dataset/dataset_builder.ts b/discojs/discojs-core/src/dataset/dataset_builder.ts index 7acf15faa..572ebf920 100644 --- a/discojs/discojs-core/src/dataset/dataset_builder.ts +++ b/discojs/discojs-core/src/dataset/dataset_builder.ts @@ -94,7 +94,7 @@ export class DatasetBuilder { } async build (config?: DataConfig): Promise { - // Require that at leat one source collection is non-empty, but not both + // Require that at least one source collection is non-empty, but not both if ((this._sources.length > 0) === (this.labelledSources.size > 0)) { throw new Error('Please provide dataset input files') } diff --git a/discojs/discojs-core/src/informant/training_informant/base.ts b/discojs/discojs-core/src/informant/training_informant/base.ts index edddb7a60..a62077c84 100644 --- a/discojs/discojs-core/src/informant/training_informant/base.ts +++ b/discojs/discojs-core/src/informant/training_informant/base.ts @@ -35,8 +35,12 @@ export abstract class Base { return this.messages.toArray() } + /** + * + * @returns the training round incremented by 1 (to start at 1 rather than 0) + */ round (): number { - return this.currentRound + return this.currentRound + 1 } participants (): number { diff --git a/discojs/discojs-core/src/informant/training_informant/federated.ts b/discojs/discojs-core/src/informant/training_informant/federated.ts index 4b7b7e8ee..67b0d73fc 100644 --- a/discojs/discojs-core/src/informant/training_informant/federated.ts +++ b/discojs/discojs-core/src/informant/training_informant/federated.ts @@ -12,7 +12,7 @@ export class FederatedInformant extends Base { * @param receivedStatistics statistics received from the server. */ update (receivedStatistics: Record): void { - this.currentRound = receivedStatistics.round + this.currentRound = receivedStatistics.round + 1 this.currentNumberOfParticipants = receivedStatistics.currentNumberOfParticipants this.totalNumberOfParticipants = receivedStatistics.totalNumberOfParticipants this.averageNumberOfParticipants = receivedStatistics.averageNumberOfParticipants diff --git a/discojs/discojs-core/src/training/disco.ts b/discojs/discojs-core/src/training/disco.ts index 02788cb87..6263a8f06 100644 --- a/discojs/discojs-core/src/training/disco.ts +++ b/discojs/discojs-core/src/training/disco.ts @@ -110,11 +110,8 @@ export class Disco { * @param dataTuple The data tuple */ async fit (dataTuple: data.DataSplit): Promise { - this.logger.success('Thank you for your contribution. Data preprocessing has started') - const trainData = dataTuple.train.preprocess().batch() const validationData = dataTuple.validation?.preprocess().batch() ?? trainData - await this.client.connect() const trainer = await this.trainer await trainer.fitModel(trainData.dataset, validationData.dataset) @@ -126,8 +123,6 @@ export class Disco { async pause (): Promise { const trainer = await this.trainer await trainer.stopTraining() - - this.logger.success('Training was successfully interrupted.') } /** diff --git a/discojs/discojs-core/src/validation/validator.ts b/discojs/discojs-core/src/validation/validator.ts index f308b284e..d47e91073 100644 --- a/discojs/discojs-core/src/validation/validator.ts +++ b/discojs/discojs-core/src/validation/validator.ts @@ -15,7 +15,7 @@ export class Validator { private readonly client?: clients.Client ) { if (source === undefined && client === undefined) { - throw new Error('cannot identify model') + throw new Error('To initialize a Validator, either or both a source and client need to be specified') } } @@ -30,46 +30,42 @@ export class Validator { async assess (data: data.Data, useConfusionMatrix?: boolean): Promise> { const batchSize = this.task.trainingInformation?.batchSize if (batchSize === undefined) { - throw new TypeError('batch size is undefined') + throw new TypeError('Batch size is undefined') } const model = await this.getModel() let features: Features[] = [] const groundTruth: number[] = [] - const predictions: number[] = [] let hits = 0 - await data.preprocess().batch().dataset.forEachAsync((e) => { - if (typeof e === 'object' && 'xs' in e && 'ys' in e) { - const xs = e.xs as tf.Tensor - - const ys = this.getLabel(e.ys as tf.Tensor) - const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor) - - const currentFeatures = xs.arraySync() - - if (Array.isArray(currentFeatures)) { - features = features.concat(currentFeatures) + // Get model predictions per batch and flatten the result + // Also build the features and groudTruth arrays + const predictions: number[] = (await data.preprocess().dataset.batch(batchSize) + .mapAsync(async e => { + if (typeof e === 'object' && 'xs' in e && 'ys' in e) { + const xs = e.xs as tf.Tensor + const ys = this.getLabel(e.ys as tf.Tensor) + const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor) + + const currentFeatures = await xs.array() + if (Array.isArray(currentFeatures)) { + features = features.concat(currentFeatures) + } else { + throw new TypeError('Data format is incorrect') + } + groundTruth.push(...Array.from(ys)) + this.size += xs.shape[0] + hits += List(pred).zip(List(ys)).filter(([p, y]) => p === y).size + // TODO: Confusion Matrix stats + const currentAccuracy = hits / this.size + this.graphInformant.updateAccuracy(currentAccuracy) + return Array.from(pred) } else { - throw new TypeError('features array is not correct') + throw new Error('Input data is missing a feature or the label') } + }).toArray()).flat() - groundTruth.push(...Array.from(ys)) - predictions.push(...Array.from(pred)) - - this.size += xs.shape[0] - - hits += List(pred).zip(List(ys)).filter(([p, y]) => p === y).size - - // TODO: Confusion Matrix stats - - const currentAccuracy = hits / this.size - this.graphInformant.updateAccuracy(currentAccuracy) - } else { - throw new Error('missing feature/label in dataset') - } - }) this.logger.success(`Obtained validation accuracy of ${this.accuracy}`) this.logger.success(`Visited ${this.visitedSamples} samples`) @@ -92,21 +88,35 @@ export class Validator { .toArray() } - async predict (data: data.Data): Promise { + async predict (data: data.Data): Promise> { const batchSize = this.task.trainingInformation?.batchSize if (batchSize === undefined) { - throw new TypeError('batch size is undefined') + throw new TypeError('Batch size is undefined') } const model = await this.getModel() - const predictions: number[] = [] + let features: Features[] = [] + + // Get model prediction per batch and flatten the result + // Also incrementally build the features array + const predictions: number[] = (await data.preprocess().dataset.batch(batchSize) + .mapAsync(async e => { + const xs = e as tf.Tensor + const currentFeatures = await xs.array() + + if (Array.isArray(currentFeatures)) { + features = features.concat(currentFeatures) + } else { + throw new TypeError('Data format is incorrect') + } - await data.dataset - .batch(batchSize) - .forEachAsync(e => - predictions.push(...(model.predict(e as tf.Tensor, { batchSize: batchSize }) as tf.Tensor).argMax(1).arraySync() as number[])) + const pred = this.getLabel(model.predict(xs, { batchSize }) as tf.Tensor) + return Array.from(pred) + }).toArray()).flat() - return predictions + return List(features).zip(List(predictions)) + .map(([f, p]) => ({ features: f, pred: p })) + .toArray() } async getModel (): Promise { @@ -118,7 +128,7 @@ export class Validator { return await this.client.getLatestModel() } - throw new Error('cannot identify model') + throw new Error('Could not load the model') } get accuracyData (): List { diff --git a/web-client/package.json b/web-client/package.json index 1f75fa17e..873643eea 100644 --- a/web-client/package.json +++ b/web-client/package.json @@ -2,7 +2,7 @@ "name": "@epfml/disco-web-client", "private": true, "scripts": { - "start": "vue-cli-service serve", + "start": "vue-cli-service serve --port 8081", "build": "vue-cli-service build", "lint": "vue-cli-service lint", "test": "vue-cli-service test:unit tests" diff --git a/web-client/src/components/pages/NotFound.vue b/web-client/src/components/pages/NotFound.vue index ff7396ace..4d4f1b347 100644 --- a/web-client/src/components/pages/NotFound.vue +++ b/web-client/src/components/pages/NotFound.vue @@ -19,10 +19,10 @@ >