diff --git a/cli/README.md b/cli/README.md index 5fdd9a759..b2f2bc903 100644 --- a/cli/README.md +++ b/cli/README.md @@ -5,7 +5,7 @@ The CLI lets one use DISCO in standalone manner (i.e. without running a server o For example, the following command trains a model on CIFAR10, using 4 federated clients for 15 epochs with a round duration of 5 epochs (see [DISCOJS.md](../docs/DISCOJS.md#rounds) for more information on rounds) > [!NOTE] -> Make sure you first ran `./get_training_data.sh` (in the root folder) to download training data. +> Make sure you first ran `./datasets/populate` (from the root folder) to download training data. ``` # From the root folder @@ -35,3 +35,12 @@ You should now be able to run your task as follows: ``` npm -w cli start -- --task your_task --numberOfUsers 4 --epochs 15 --roundDuration 5 ``` + +## Benchmarking GPT-TF.js + +The CLI also allows benchmarking the time and memory requirements of the gpt-tfjs implementation in DISCO. The last benchmark has been reported in [this PR](https://github.com/epfml/disco/pull/659). +CLI options can be listed with `npm -w cli run benchmark_gpt -- -h`. + +To benchmark model training, you can run `npm -w cli run benchmark_gpt -- --modelType gpt-nano --contextLength 128 --batchSize 8`. + +For inference run `npm -w cli run benchmark_gpt -- --inference --modelPath `. You can use the `docs/example/wikitext` example script to train a model. The model needs to be trained on the wikitext default task to ensure that model parameters such as vocab size, tokenizer, max sequence length are the same between training and inference. diff --git a/cli/package.json b/cli/package.json index 09052f1ab..51d9c6d04 100644 --- a/cli/package.json +++ b/cli/package.json @@ -6,6 +6,7 @@ "scripts": { "watch": "nodemon --ext ts --ignore dist --watch ../discojs/discojs-node/dist --watch ../server/dist --watch . --exec npm run", "start": "npm run build && node dist/cli.js", + "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", "build": "tsc", "lint": "npx eslint .", "test": ": nothing" diff --git a/cli/src/benchmark_gpt.ts b/cli/src/benchmark_gpt.ts new file mode 100644 index 000000000..462b94f6a --- /dev/null +++ b/cli/src/benchmark_gpt.ts @@ -0,0 +1,131 @@ +import { parse } from 'ts-command-line-args'; +import type { Task } from '@epfml/discojs-core' +import { fetchTasks, data, models } from '@epfml/discojs-core' +import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node' +import { startServer } from '@epfml/disco-server' + +interface CLIArguments{ + modelType?: string; // 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2' + contextLength?: number; // 128, 256, 512, 1024, 2048 + batchSize?: number; // 8, 16, 32, 64 + inference?: boolean; // benchmark inference if true, training otherwise + modelPath?: string; + help?: boolean // print help +} + +const parsedArgs = parse({ + modelType: { type: String, optional: true, description: "A GPT architecture: 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2'" }, + contextLength: { type: Number, optional: true, description: "The maximum input sequence length to train the model on" }, + batchSize: { type: Number, optional: true, description: "The model training bat size" }, + inference: { type: Boolean, optional: true, description: "Whether to benchmark the model inference or training" }, + modelPath: { type: String, optional: true, description: "If benchmarking inference, the path to the trained model" }, + help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }, +}, {helpArg: 'help'}); + +const defaultArgs: Required = { + modelType: 'gpt-nano', + contextLength: 128, + batchSize: 8, + inference: false, + modelPath: 'models/model.json', + help: false +} + +// Fill parsed args with default args +const args = { ...defaultArgs, ...parsedArgs } + +/** + * Benchmark results are reported in https://github.com/epfml/disco/pull/659 + */ + +async function main(args: Required): Promise { + const { inference: benchmarkInference, modelType, + contextLength, batchSize, modelPath } = args + + // Launch a server instance + const [server, url] = await startServer() + + // const url = new URL('http://localhost:8080') + + // Fetch the wikitext task from the server + const tasks = await fetchTasks(url) + const task = tasks.get('wikitext-103') + if (task === undefined) { throw new Error('task not found') } + + /** + * Training benchmark + */ + if (!benchmarkInference) { + // Benchmark parameters + const epoch = 1 + const iterationsPerEpoch = 10 + + const config: models.GPTConfig = { + modelType: modelType as models.GPTConfig['modelType'], + maxIter: iterationsPerEpoch, + blockSize: contextLength, + lr: 0.0001, + vocabSize: 50258 // default wikitext task uses the gpt2 tokenizer with vocabSize 50258 + } + + // Load the dataset after setting the Task batch size and max sequence length + // to make sure the dataset is batched and tokenized correctly + task.trainingInformation.batchSize = batchSize + task.trainingInformation.maxSequenceLength = contextLength + const dataset = await loadWikitextData(task) + const preprocessedDataset = dataset.train.preprocess().batch().dataset + + // Init and train the model + const model = new models.GPT(config) + console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`) + + let epochTime = performance.now() + const logGenerator = model.train(preprocessedDataset, undefined, epoch) + for await (const logs of logGenerator) { + epochTime = (performance.now() - epochTime) + const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epoch) + console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token
${logs.peakMemory.toFixed(2)} GB`) + } + + /** + * Inference benchmark + */ + } else { + const model = await loadModelFromDisk(modelPath) + if (!(model instanceof models.GPT)){ + throw new Error("Loaded model isn't a GPT model") + } + // Retrieve the tokenizer used during training + const tokenizer = await models.getTaskTokenizer(task) + + // Benchmark parameters + const prompt = 'The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion,' + const nbNewTokens = 200 + const iterations = 10 + console.log("Generating", nbNewTokens, "new tokens") + + let inferenceTime = 0 + for (let i = 0; i < iterations; i++) { + const timeStart = performance.now() + const _ = await model.generate(prompt, tokenizer, nbNewTokens) + inferenceTime += performance.now() - timeStart + } + // Overall average includes tokenization, token sampling and de-tokenization + console.log(`Inference time: ${(inferenceTime/ nbNewTokens / iterations).toFixed(2)} ms/token`) + } + await new Promise((resolve, reject) => { + server.once('close', resolve) + server.close(reject) + }) +} + +async function loadWikitextData (task: Task): Promise { + const loader = new NodeTextLoader(task) + const dataSplit: data.DataSplit = { + train: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.train.tokens', {shuffle: true}), task) + } + return dataSplit +} + +// You can run this example with "npm start" from this folder +main(args).catch(console.error) diff --git a/discojs/discojs-core/src/models/gpt/config.ts b/discojs/discojs-core/src/models/gpt/config.ts index 727412563..17515ec9f 100644 --- a/discojs/discojs-core/src/models/gpt/config.ts +++ b/discojs/discojs-core/src/models/gpt/config.ts @@ -1,11 +1,11 @@ type ModelType = - | 'gpt2' - | 'gpt2-medium' - | 'gpt2-large' - | 'gpt2-xl' - | 'gpt-mini' - | 'gpt-micro' - | 'gpt-nano' + | 'gpt2' + | 'gpt2-medium' + | 'gpt2-large' + | 'gpt2-xl' + | 'gpt-mini' + | 'gpt-micro' + | 'gpt-nano' export interface GPTConfig { lr: number @@ -30,7 +30,7 @@ export interface GPTConfig { nHead?: number nEmbd?: number } - +// for a benchmark of performance, see https://github.com/epfml/disco/pull/659 export const DEFAULT_CONFIG: Required = { name: 'transformer', lr: 0.001, @@ -77,9 +77,5 @@ export function getModelSizes (modelType: ModelType): Required { return { nLayer: 4, nHead: 4, nEmbd: 128 } case 'gpt-nano': return { nLayer: 3, nHead: 3, nEmbd: 48 } - default: { - const _: never = modelType - throw new Error("should never happen") - } } } diff --git a/discojs/discojs-core/src/models/gpt/index.ts b/discojs/discojs-core/src/models/gpt/index.ts index f4a47a3c4..73ca76cb6 100644 --- a/discojs/discojs-core/src/models/gpt/index.ts +++ b/discojs/discojs-core/src/models/gpt/index.ts @@ -45,16 +45,16 @@ export class GPT extends Model { }; for (let epoch = 0; epoch < epochs; epoch++) { await this.model.fitDataset(trainingData, trainingArgs); - if (logs === undefined) { - throw new Error("epoch didn't gave any logs"); + throw new Error("Epoch didn't gave any logs"); } - const { loss, val_acc, val_loss } = logs; + const { loss, val_acc, val_loss, peakMemory } = logs; if (loss === undefined || isNaN(loss)) { - throw new Error("Invalid training logs"); + throw new Error("Training loss is undefined or nan"); } const structuredLogs: EpochLogs = { epoch, + peakMemory, training: { loss: logs.loss } @@ -67,7 +67,6 @@ export class GPT extends Model { } structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss} } - yield structuredLogs } } @@ -81,14 +80,13 @@ export class GPT extends Model { return Promise.resolve(ret) } - async generate (input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise { + async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise { const { input_ids: tokens } = await tokenizer(input, { return_tensor: false}) as { input_ids: number[] } const generationConfig = { maxNewTokens: newTokens, temperature: 1.0, - doSample: false, - topK: null + doSample: false } const predictedTokens = await this.model.generate(tokens, generationConfig) const generatedWords = tokenizer.decode(predictedTokens[0]) @@ -118,6 +116,17 @@ export class GPT extends Model { config: this.config } } + + [Symbol.dispose](): void{ + console.log("Disposing model") + if (this.model.optimizer !== undefined) { + this.model.optimizer.dispose() + } + // Some tensors are not cleaned up when model.dispose is called + // So we dispose them manually + this.model.disposeRefs() + this.model.dispose() + } } export type GPTSerialization = { diff --git a/discojs/discojs-core/src/models/gpt/layers.ts b/discojs/discojs-core/src/models/gpt/layers.ts index 243de000a..05b720eb4 100644 --- a/discojs/discojs-core/src/models/gpt/layers.ts +++ b/discojs/discojs-core/src/models/gpt/layers.ts @@ -59,13 +59,12 @@ class CausalSelfAttention extends tf.layers.Layer { private readonly dropout: number private readonly bias: boolean private readonly mask: tf.Tensor2D - cAttnKernel?: tf.LayerVariable cAttnBias?: tf.LayerVariable cProjKernel?: tf.LayerVariable cProjBias?: tf.LayerVariable - constructor (private readonly config: CausalSelfAttentionConfig) { + constructor (private readonly config: CausalSelfAttentionConfig, disposalRefs: tf.TensorContainer[], private peakMemory: {value: number}) { super(config) this.nEmbd = config.nEmbd @@ -77,6 +76,7 @@ class CausalSelfAttention extends tf.layers.Layer { // calling bandPart zero out the upper triangular part of the all-ones matrix // from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0) + disposalRefs.push(this.mask) // Push a reference to dispose this matrix later } build (): void { @@ -188,7 +188,10 @@ class CausalSelfAttention extends tf.layers.Layer { y = tf.reshape(y, [B, T, C]) y = dense(y, this.cProjKernel, this.cProjBias) y = kwargs.training === true ? tf.dropout(y, this.dropout) : y - + const memoryAllocated = tf.memory().numBytes / 1024 / 1024 / 1024 // GB + if (memoryAllocated > this.peakMemory.value) { + this.peakMemory.value = memoryAllocated + } return y }) } @@ -257,7 +260,7 @@ function MLP (config: MLPConfig): tf.LayersModel { type BlockConfig = CausalSelfAttentionConfig & MLPConfig & { debug: boolean } -function TransformerBlock (conf: BlockConfig): tf.LayersModel { +function TransformerBlock (conf: BlockConfig, disposalRefs: tf.TensorContainer[], peakMemory: {value: number}): tf.LayersModel { const config = Object.assign({ name: 'h' }, conf) const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] }) let x1, x2 @@ -269,7 +272,9 @@ function TransformerBlock (conf: BlockConfig): tf.LayersModel { } // self attention layer x1 = new CausalSelfAttention( - Object.assign({}, config, { name: config.name + '/attn' }) + Object.assign({}, config, { name: config.name + '/attn' }), + disposalRefs, + peakMemory ).apply(x1) // Residual connection x1 = tf.layers.add().apply([inputs, x1 as tf.SymbolicTensor]) @@ -295,7 +300,10 @@ function TransformerBlock (conf: BlockConfig): tf.LayersModel { * @param conf GPTConfig * @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply */ -export function GPTArchitecture (config: Required): tf.LayersModel { +export function GPTArchitecture( + config: Required, + disposalRefs: tf.TensorContainer[], + peakMemory: {value: number }): tf.LayersModel { const inputs = tf.input({ shape: [null] }) //Token embedding @@ -325,7 +333,7 @@ export function GPTArchitecture (config: Required): tf.LayersModel { // token and positional embeddings are added together let x = tf.layers.add().apply([tokEmb, posEmb]) - //dropout + // dropout x = tf.layers.dropout({name: 'drop', rate: config.embdDrop}).apply(x) if (config.debug) { x = new LogLayer({ name: 'dropadd' }).apply(x) @@ -334,7 +342,9 @@ export function GPTArchitecture (config: Required): tf.LayersModel { //Apply successively transformer blocks, attention and dense layers for (let i = 0; i < config.nLayer; i++) { x = TransformerBlock( - Object.assign({}, config, { name: config.name + '/h/' + i }) + Object.assign({}, config, { name: config.name + '/h/' + i }), + disposalRefs, + peakMemory ).apply(x) } // Normalization diff --git a/discojs/discojs-core/src/models/gpt/model.ts b/discojs/discojs-core/src/models/gpt/model.ts index b4a88e7eb..13e725749 100644 --- a/discojs/discojs-core/src/models/gpt/model.ts +++ b/discojs/discojs-core/src/models/gpt/model.ts @@ -24,18 +24,34 @@ export declare abstract class Dataset { */ class GPTModel extends tf.LayersModel { protected readonly config: Required + private readonly disposalRefs: tf.TensorContainer[] // Array to store tensor to dispose manually + // Object to pass down to layers to store max memory allocated + // This is an object rather than a primitive to pass the reference + protected peakMemory: { value: number } constructor(partialConfig?: GPTConfig) { - // Complete missing config parameters with default values + // Fill missing config parameters with default values let completeConfig: Required = { ...DEFAULT_CONFIG, ...partialConfig } // Add layer sizes depending on which model has been specified completeConfig = { ...completeConfig, ...getModelSizes(completeConfig.modelType) } // Init the tf.LayersModel and assign it to this - const gpt = GPTArchitecture(completeConfig) + const disposalRefs: tf.TensorContainer[] = [] + const peakMemory: { value: number } = {value: 0} + const gpt = GPTArchitecture(completeConfig, disposalRefs, peakMemory) const { inputs, outputs, name } = gpt super({ inputs, outputs, name }) this.config = completeConfig + this.disposalRefs = disposalRefs + this.peakMemory = peakMemory + } + + // Some tensors are not cleaned up when model.dispose is called + // So we dispose them manually + disposeRefs() { + for (const tensorContainer of this.disposalRefs) { + tf.dispose([tensorContainer]) + } } get getGPTConfig() { @@ -46,24 +62,23 @@ class GPTModel extends tf.LayersModel { this.optimizer = this.config.weightDecay !== 0 ? getCustomAdam(this, this.config.lr, this.config.weightDecay) : tf.train.adam(this.config.lr) + this.peakMemory.value = 0 } async fitDataset(dataset: Dataset, trainingArgs: tf.ModelFitDatasetArgs): Promise { const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> - await callbacks.onTrainBegin?.() + for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) { let averageLoss = 0 let iteration = 1 const iterator = await dataset.iterator() + let preprocessingTime = performance.now() + let next = await iterator.next() + preprocessingTime = performance.now() - preprocessingTime - let continueTraining = true - while (continueTraining) { - let preprocessingTime = performance.now() - const next = await iterator.next() - preprocessingTime = performance.now() - preprocessingTime - + while (next.done !== true && iteration <= this.config.maxIter) { let weightUpdateTime = performance.now() await callbacks.onEpochBegin?.(epoch) const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D } @@ -77,39 +92,54 @@ class GPTModel extends tf.LayersModel { } return tf.losses.softmaxCrossEntropy(ys, logits) } - + let backwardPassMemory = 0 const lossTensor = tf.tidy(() => { const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn) const gradsClipped = clipByGlobalNormObj(grads, 1) this.optimizer.applyGradients(gradsClipped) + backwardPassMemory = tf.memory().numBytes / 1024 / 1024 / 1024 return lossTensor }) const loss = await lossTensor.array() averageLoss += loss - tf.dispose([xs, ys, lossTensor, next.value]) - weightUpdateTime = performance.now() - weightUpdateTime + // Probably never the case. Empirically the attention mechanism always allocates + // more memory than the backward pass + if (backwardPassMemory > this.peakMemory.value) { + this.peakMemory.value = backwardPassMemory + } + tf.dispose([xs, ys, lossTensor]) + + + if ( + evalDataset !== undefined && + this.config.evaluateEvery !== undefined && + iteration % this.config.evaluateEvery == 0 + ){ + const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches) + console.log(iterationLogs) + } console.log( `Epoch: ${epoch}`, `\tStep: ${iteration} / ${this.config.maxIter}`, `\tLoss: ${loss.toFixed(3)}`, - `\tMemory: ${(tf.memory().numBytes / 1024 / 1024).toFixed(2)} MB`, + `\tPeak memory: ${this.peakMemory.value.toFixed(2)} GB`, `\tNumber of tensors allocated: ${tf.memory().numTensors}`, `\tPreprocessing time: ${preprocessingTime.toFixed(0)} ms`, `\tWeight update time: ${weightUpdateTime.toFixed(0)} ms` ) - - if (evalDataset !== undefined && this.config.evaluateEvery !== undefined - && iteration % this.config.evaluateEvery == 0) { - const logs = await evaluate(this, evalDataset, this.config.maxEvalBatches) - console.log(logs) - } iteration++ - continueTraining = next.done !== true && iteration <= this.config.maxIter + next = await iterator.next() + } + // Memory leak: If we reached the last iteration rather than the end of the dataset, cleanup the tensors + if (next.done != true && iteration > this.config.maxIter) { + const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D } + tf.dispose([xs, ys]) } let logs: tf.Logs = { - 'loss': averageLoss / iteration + 'loss': averageLoss / iteration, + 'peakMemory': this.peakMemory.value } if (evalDataset !== undefined) { logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) } @@ -163,39 +193,33 @@ function prepareIdx (idx: tf.TensorLike): tf.Tensor2D { * */ export class GPTForCausalLM extends GPTModel { - async generate (idxRaw: tf.TensorLike, conf: GenerateConfig, act?: (_: { idxNext: number[][], timePerToken: number }) => Promise): Promise { + async generate (idxRaw: tf.TensorLike, conf: GenerateConfig): Promise { const config = Object.assign({}, defaultGenerateConfig, conf) let idx = prepareIdx(idxRaw) for (let step = 0; step < config.maxNewTokens; step++) { - const { idxNext, timePerToken } = this.generateOnce(this, idx, config) + const idxNext = this.generateOnce(this, idx, config) const idxNew = idx.concat(idxNext, 1) tf.dispose(idx) idx = idxNew - const idxNextArr = await idxNext.array() tf.dispose(idxNext) - if (act !== undefined) { - await act({ idxNext: idxNextArr, timePerToken }) - } } const idxArr = await idx.array() tf.dispose(idx) return idxArr } - private generateOnce (model: tf.LayersModel, idx: tf.Tensor2D, config: GenerateConfig): { idxNext: tf.Tensor2D, timePerToken: number } { - let timePerToken = performance.now() - + private generateOnce (model: tf.LayersModel, idx: tf.Tensor2D, config: GenerateConfig): tf.Tensor2D { const idxNext = tf.tidy(() => { + // slice input tokens if longer than context length const blockSize = this.config.blockSize - const idxCond = idx.shape[1] <= blockSize - ? idx : idx.slice([0, -blockSize], [-1, -1]) - - const output = model.predict(idxCond) + idx = idx.shape[1] <= blockSize + ? idx : idx.slice([0, idx.shape[1] - blockSize]) + + const output = model.predict(idx) if (Array.isArray(output)) throw new Error('The model outputs too multiple values') if (output.shape.length !== 3) throw new Error('The model outputs wrong shape') const logits = output as tf.Tensor3D - - timePerToken = performance.now() - timePerToken + const logitsScaled = logits .slice([0, idx.shape[1] - 1, 0]) .reshape([logits.shape[0], logits.shape[2]]) @@ -207,10 +231,6 @@ export class GPTForCausalLM extends GPTModel { return probs.argMax(-1).expandDims(1) } }) - - return { - idxNext, - timePerToken - } + return idxNext } } diff --git a/discojs/discojs-core/src/models/model.ts b/discojs/discojs-core/src/models/model.ts index b4696445a..60b61764c 100644 --- a/discojs/discojs-core/src/models/model.ts +++ b/discojs/discojs-core/src/models/model.ts @@ -13,6 +13,7 @@ export interface EpochLogs { loss: number, accuracy: number }; + peakMemory: number; } // TODO still bound to tfjs @@ -25,7 +26,7 @@ export type Sample = tf.Tensor; * Allow for various implementation of models (various train function, tensor-library, ...) **/ // TODO make it typesafe: same shape of data/input/weights -export abstract class Model { +export abstract class Model implements Disposable{ // TODO don't allow external access but upgrade train to return weights on every epoch /** Return training state */ abstract get weights(): WeightsContainer; @@ -50,4 +51,16 @@ export abstract class Model { /** Predict likely values */ // TODO extract in separated TrainedModel? abstract predict(input: Sample): Promise; + + + /** + * This method is automatically called to cleanup the memory occupied by the model + * when leaving the definition scope if the instance has been defined with the `using` keyword. + * For example: + * function f() { + * using model = new Model(); + * } + * Calling f() will call the model's dispose method when exiting the function. + */ + abstract [Symbol.dispose](): void; } diff --git a/discojs/discojs-core/src/models/tfjs.ts b/discojs/discojs-core/src/models/tfjs.ts index 1b3ffc46d..8646c34cd 100644 --- a/discojs/discojs-core/src/models/tfjs.ts +++ b/discojs/discojs-core/src/models/tfjs.ts @@ -34,23 +34,31 @@ export class TFJS extends Model { ): AsyncGenerator { for (let epoch = 0; epoch < epochs; epoch++) { let logs: tf.Logs | undefined; - + let peakMemory = 0 await this.model.fitDataset(trainingData, { epochs: 1, validationData, - callbacks: { onEpochEnd: (_, cur) => { logs = cur } }, + callbacks: { + onBatchEnd: (_) => { + const currentMemory = tf.memory().numBytes / 1024 / 1024 / 1024 // GB + if (currentMemory > peakMemory) { + peakMemory = currentMemory + } + }, + onEpochEnd: (_, cur) => { logs = cur } + }, }); if (logs === undefined) { - throw new Error("epoch didn't gave any logs"); + throw new Error("Epoch didn't gave any logs"); } const { loss, acc, val_acc, val_loss } = logs; - console.log(logs) if (loss === undefined || isNaN(loss) || acc === undefined || isNaN(acc)) { - throw new Error("Invalid training logs"); + throw new Error("Training loss is undefined or nan"); } const structuredLogs: EpochLogs = { epoch, + peakMemory, training: { loss: logs.loss, accuracy: logs.acc, @@ -61,7 +69,10 @@ export class TFJS extends Model { val_acc === undefined || isNaN(val_acc)) { throw new Error("Invalid validation logs"); } - structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss} + structuredLogs.validation = { + accuracy: logs.val_acc, + loss: logs.val_loss + } } yield structuredLogs } @@ -103,6 +114,10 @@ export class TFJS extends Model { return await ret } + [Symbol.dispose](): void{ + this.model.dispose() + } + /** * extract wrapped model * diff --git a/docs/examples/package.json b/docs/examples/package.json index d5b05106b..6c75c7c16 100644 --- a/docs/examples/package.json +++ b/docs/examples/package.json @@ -7,6 +7,7 @@ "train": "npm run build && node dist/training.js", "custom_task": "npm run build && node dist/custom_task.js", "language_model": "npm run build && node dist/wikitext.js", + "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", "build": "tsc", "lint": "npx eslint .", "test": "npm run train" diff --git a/docs/examples/wikitext.ts b/docs/examples/wikitext.ts index 7b46831fb..e76edc65f 100644 --- a/docs/examples/wikitext.ts +++ b/docs/examples/wikitext.ts @@ -47,8 +47,8 @@ async function main(): Promise { // Retrieve the tokenizer used during training const tokenizer = await models.getTaskTokenizer(task) const prompt = 'The game began development in 2010 , carrying over a large portion' - const generations = await model.generate(prompt, tokenizer) - console.log(generations) + const generation = await model.generate(prompt, tokenizer) + console.log(generation) } async function loadWikitextData (task: Task): Promise { diff --git a/package-lock.json b/package-lock.json index e944137c1..dfea7d7e0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -54,7 +54,7 @@ "ws": "8" }, "devDependencies": { - "@tensorflow/tfjs-node": "^4.17.0", + "@tensorflow/tfjs-node": "4", "@types/chai": "4", "@types/mocha": "10", "@types/simple-peer": "9", @@ -2655,30 +2655,30 @@ } }, "node_modules/@volar/language-core": { - "version": "2.2.0-alpha.7", - "resolved": "https://registry.npmjs.org/@volar/language-core/-/language-core-2.2.0-alpha.7.tgz", - "integrity": "sha512-igpp+nTkyl8faVzRJMpSCeA4XlBJ5UVSyc/WGyksmUmP10YbfufbcQCFlxEXv2uMBV+a3L4JVCj+Vju+08FOSA==", + "version": "2.2.0-alpha.8", + "resolved": "https://registry.npmjs.org/@volar/language-core/-/language-core-2.2.0-alpha.8.tgz", + "integrity": "sha512-Ew1Iw7/RIRNuDLn60fWJdOLApAlfTVPxbPiSLzc434PReC9kleYtaa//Wo2WlN1oiRqneW0pWQQV0CwYqaimLQ==", "dev": true, "dependencies": { - "@volar/source-map": "2.2.0-alpha.7" + "@volar/source-map": "2.2.0-alpha.8" } }, "node_modules/@volar/source-map": { - "version": "2.2.0-alpha.7", - "resolved": "https://registry.npmjs.org/@volar/source-map/-/source-map-2.2.0-alpha.7.tgz", - "integrity": "sha512-iIZM2EovdEnr6mMwlsnt4ciix4xz7HSGHyUSviRaY5cii5PMXGHeUU9UDeb+xzLCx8kdk3L5J4z+ts50AhkYcg==", + "version": "2.2.0-alpha.8", + "resolved": "https://registry.npmjs.org/@volar/source-map/-/source-map-2.2.0-alpha.8.tgz", + "integrity": "sha512-E1ZVmXFJ5DU4fWDcWHzi8OLqqReqIDwhXvIMhVdk6+VipfMVv4SkryXu7/rs4GA/GsebcRyJdaSkKBB3OAkIcA==", "dev": true, "dependencies": { "muggle-string": "^0.4.0" } }, "node_modules/@volar/typescript": { - "version": "2.2.0-alpha.7", - "resolved": "https://registry.npmjs.org/@volar/typescript/-/typescript-2.2.0-alpha.7.tgz", - "integrity": "sha512-qy04/hx4UbW1BdPlzaxlH60D4plubcyqdbYM6Y5vZiascZxFowtd6vE39Td9FYzDxwcKgzb/Crvf/ABhdHnuBA==", + "version": "2.2.0-alpha.8", + "resolved": "https://registry.npmjs.org/@volar/typescript/-/typescript-2.2.0-alpha.8.tgz", + "integrity": "sha512-RLbRDI+17CiayHZs9HhSzlH0FhLl/+XK6o2qoiw2o2GGKcyD1aDoY6AcMd44acYncTOrqoTNoY6LuCiRyiJiGg==", "dev": true, "dependencies": { - "@volar/language-core": "2.2.0-alpha.7", + "@volar/language-core": "2.2.0-alpha.8", "path-browserify": "^1.0.1" } }, @@ -2782,12 +2782,12 @@ } }, "node_modules/@vue/language-core": { - "version": "2.0.12", - "resolved": "https://registry.npmjs.org/@vue/language-core/-/language-core-2.0.12.tgz", - "integrity": "sha512-aIStDPt69SHOpiIckGTIIjEz/sXc6ZfCMS5uWYL1AcbcRMhzFCLZscGAVte1+ad+RRFepSpKBjGttyPcgKJ7ww==", + "version": "2.0.13", + "resolved": "https://registry.npmjs.org/@vue/language-core/-/language-core-2.0.13.tgz", + "integrity": "sha512-oQgM+BM66SU5GKtUMLQSQN0bxHFkFpLSSAiY87wVziPaiNQZuKVDt/3yA7GB9PiQw0y/bTNL0bOc0jM/siYjKg==", "dev": true, "dependencies": { - "@volar/language-core": "2.2.0-alpha.7", + "@volar/language-core": "2.2.0-alpha.8", "@vue/compiler-dom": "^3.4.0", "@vue/shared": "^3.4.0", "computeds": "^0.0.1", @@ -11119,13 +11119,13 @@ } }, "node_modules/vue-tsc": { - "version": "2.0.12", - "resolved": "https://registry.npmjs.org/vue-tsc/-/vue-tsc-2.0.12.tgz", - "integrity": "sha512-thlBBWlPYrNdba535oDdxz7PRUufZgRZRVP5Aql5wBVpGSWSeqou4EzFXeKVoZr59lp9hJROubDVzlhACmcEhg==", + "version": "2.0.13", + "resolved": "https://registry.npmjs.org/vue-tsc/-/vue-tsc-2.0.13.tgz", + "integrity": "sha512-a3nL3FvguCWVJUQW/jFrUxdeUtiEkbZoQjidqvMeBK//tuE2w6NWQAbdrEpY2+6nSa4kZoKZp8TZUMtHpjt4mQ==", "dev": true, "dependencies": { - "@volar/typescript": "2.2.0-alpha.7", - "@vue/language-core": "2.0.12", + "@volar/typescript": "2.2.0-alpha.8", + "@vue/language-core": "2.0.13", "semver": "^7.5.4" }, "bin": {