Benchmark GPT-tfjs #659

Merged 17 commits on May 1, 2024
11 changes: 10 additions & 1 deletion cli/README.md
@@ -5,7 +5,7 @@ The CLI lets one use DISCO in standalone manner (i.e. without running a server o
For example, the following command trains a model on CIFAR10, using 4 federated clients for 15 epochs with a round duration of 5 epochs (see [DISCOJS.md](../docs/DISCOJS.md#rounds) for more information on rounds)

> [!NOTE]
> Make sure you first ran `./get_training_data.sh` (in the root folder) to download training data.
> Make sure you first ran `./datasets/populate` (from the root folder) to download training data.

```
# From the root folder
@@ -35,3 +35,12 @@ You should now be able to run your task as follows:
```
npm -w cli start -- --task your_task --numberOfUsers 4 --epochs 15 --roundDuration 5
```

## Benchmarking GPT-TF.js

The CLI also allows benchmarking the time and memory requirements of the gpt-tfjs implementation in DISCO. The latest benchmark results are reported in [this PR](https://github.com/epfml/disco/pull/659).
CLI options can be listed with `npm -w cli run benchmark_gpt -- -h`.

To benchmark model training, you can run `npm -w cli run benchmark_gpt -- --modelType gpt-nano --contextLength 128 --batchSize 8`.

For inference, run `npm -w cli run benchmark_gpt -- --inference --modelPath <path to trained model json file>`. You can use the `docs/example/wikitext` example script to train a model. The model needs to be trained on the default wikitext task to ensure that parameters such as the vocabulary size, tokenizer and maximum sequence length are the same between training and inference.
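
The training number printed by the benchmark is a per-token average over one epoch. A minimal sketch of how that figure is derived, mirroring the computation in `cli/src/benchmark_gpt.ts` added in this PR (the timing value below is only a placeholder; the real script measures it with `performance.now()`):

```ts
// Sketch: converting one epoch's wall-clock time into the reported ms/token.
const batchSize = 8;
const contextLength = 128;
const iterationsPerEpoch = 10;
const epochTimeMs = 12_000; // illustrative: the epoch took 12 seconds

const tokensProcessed = batchSize * contextLength * iterationsPerEpoch;
const msPerToken = epochTimeMs / tokensProcessed;
console.log(`Training time: ${msPerToken.toFixed(2)} ms/token`);
```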
1 change: 1 addition & 0 deletions cli/package.json
@@ -6,6 +6,7 @@
"scripts": {
"watch": "nodemon --ext ts --ignore dist --watch ../discojs/discojs-node/dist --watch ../server/dist --watch . --exec npm run",
"start": "npm run build && node dist/cli.js",
"benchmark_gpt": "npm run build && node dist/benchmark_gpt.js",
"build": "tsc",
"lint": "npx eslint .",
"test": ": nothing"
131 changes: 131 additions & 0 deletions cli/src/benchmark_gpt.ts
@@ -0,0 +1,131 @@
import { parse } from 'ts-command-line-args';
import type { Task } from '@epfml/discojs-core'
import { fetchTasks, data, models } from '@epfml/discojs-core'
import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node'
import { startServer } from '@epfml/disco-server'

interface CLIArguments{
modelType?: string; // 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2'
contextLength?: number; // 128, 256, 512, 1024, 2048
batchSize?: number; // 8, 16, 32, 64
inference?: boolean; // benchmark inference if true, training otherwise
modelPath?: string;
help?: boolean // print help
}

const parsedArgs = parse<CLIArguments>({
modelType: { type: String, optional: true, description: "A GPT architecture: 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2'" },
contextLength: { type: Number, optional: true, description: "The maximum input sequence length to train the model on" },
batchSize: { type: Number, optional: true, description: "The model training batch size" },
inference: { type: Boolean, optional: true, description: "Whether to benchmark the model inference or training" },
modelPath: { type: String, optional: true, description: "If benchmarking inference, the path to the trained model" },
help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' },
}, {helpArg: 'help'});

const defaultArgs: Required<CLIArguments> = {
modelType: 'gpt-nano',
contextLength: 128,
batchSize: 8,
inference: false,
modelPath: 'models/model.json',
help: false
}

// Fill parsed args with default args
const args = { ...defaultArgs, ...parsedArgs }

/**
 * Benchmark results are reported in https://github.com/epfml/disco/pull/659
*/

async function main(args: Required<CLIArguments>): Promise<void> {
const { inference: benchmarkInference, modelType,
contextLength, batchSize, modelPath } = args

// Launch a server instance
const [server, url] = await startServer()

// const url = new URL('http://localhost:8080')

// Fetch the wikitext task from the server
const tasks = await fetchTasks(url)
const task = tasks.get('wikitext-103')
if (task === undefined) { throw new Error('task not found') }

/**
* Training benchmark
*/
if (!benchmarkInference) {
// Benchmark parameters
const epoch = 1
const iterationsPerEpoch = 10

const config: models.GPTConfig = {
modelType: modelType as models.GPTConfig['modelType'],
maxIter: iterationsPerEpoch,
blockSize: contextLength,
lr: 0.0001,
vocabSize: 50258 // default wikitext task uses the gpt2 tokenizer with vocabSize 50258
}

// Load the dataset after setting the Task batch size and max sequence length
// to make sure the dataset is batched and tokenized correctly
task.trainingInformation.batchSize = batchSize
task.trainingInformation.maxSequenceLength = contextLength
const dataset = await loadWikitextData(task)
const preprocessedDataset = dataset.train.preprocess().batch().dataset

// Init and train the model
const model = new models.GPT(config)
console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`)

let epochTime = performance.now()
const logGenerator = model.train(preprocessedDataset, undefined, epoch)
for await (const logs of logGenerator) {
epochTime = (performance.now() - epochTime)
const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epoch)
console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token <br> ${logs.peakMemory.toFixed(2)} GB`)
}

/**
* Inference benchmark
*/
} else {
const model = await loadModelFromDisk(modelPath)
if (!(model instanceof models.GPT)){
throw new Error("Loaded model isn't a GPT model")
}
// Retrieve the tokenizer used during training
const tokenizer = await models.getTaskTokenizer(task)

// Benchmark parameters
const prompt = 'The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion,'
const nbNewTokens = 200
const iterations = 10
console.log("Generating", nbNewTokens, "new tokens")

let inferenceTime = 0
for (let i = 0; i < iterations; i++) {
const timeStart = performance.now()
const _ = await model.generate(prompt, tokenizer, nbNewTokens)
inferenceTime += performance.now() - timeStart
}
// Overall average includes tokenization, token sampling and de-tokenization
console.log(`Inference time: ${(inferenceTime/ nbNewTokens / iterations).toFixed(2)} ms/token`)
}
await new Promise((resolve, reject) => {
server.once('close', resolve)
server.close(reject)
})
}

async function loadWikitextData (task: Task): Promise<data.DataSplit> {
const loader = new NodeTextLoader(task)
const dataSplit: data.DataSplit = {
train: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.train.tokens', {shuffle: true}), task)
}
return dataSplit
}

// You can run this benchmark with `npm -w cli run benchmark_gpt` (see the CLI README for the available options)
main(args).catch(console.error)
20 changes: 8 additions & 12 deletions discojs/discojs-core/src/models/gpt/config.ts
@@ -1,11 +1,11 @@
type ModelType =
| 'gpt2'
| 'gpt2-medium'
| 'gpt2-large'
| 'gpt2-xl'
| 'gpt-mini'
| 'gpt-micro'
| 'gpt-nano'
| 'gpt2'
| 'gpt2-medium'
| 'gpt2-large'
| 'gpt2-xl'
| 'gpt-mini'
| 'gpt-micro'
| 'gpt-nano'

export interface GPTConfig {
lr: number
@@ -30,7 +30,7 @@ export interface GPTConfig {
nHead?: number
nEmbd?: number
}

// for a benchmark of performance, see https://github.com/epfml/disco/pull/659
export const DEFAULT_CONFIG: Required<GPTConfig> = {
name: 'transformer',
lr: 0.001,
@@ -77,9 +77,5 @@ export function getModelSizes (modelType: ModelType): Required<ModelSize> {
return { nLayer: 4, nHead: 4, nEmbd: 128 }
case 'gpt-nano':
return { nLayer: 3, nHead: 3, nEmbd: 48 }
default: {
const _: never = modelType
throw new Error("should never happen")
}
}
}
25 changes: 17 additions & 8 deletions discojs/discojs-core/src/models/gpt/index.ts
@@ -45,16 +45,16 @@ export class GPT extends Model {
};
for (let epoch = 0; epoch < epochs; epoch++) {
await this.model.fitDataset(trainingData, trainingArgs);

if (logs === undefined) {
throw new Error("epoch didn't gave any logs");
throw new Error("Epoch didn't gave any logs");
}
const { loss, val_acc, val_loss } = logs;
const { loss, val_acc, val_loss, peakMemory } = logs;
if (loss === undefined || isNaN(loss)) {
throw new Error("Invalid training logs");
throw new Error("Training loss is undefined or nan");
}
const structuredLogs: EpochLogs = {
epoch,
peakMemory,
training: {
loss: logs.loss
}
@@ -67,7 +67,6 @@ }
}
structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss}
}

yield structuredLogs
}
}
@@ -81,14 +80,13 @@
return Promise.resolve(ret)
}

async generate (input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {
async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {
const { input_ids: tokens } = await tokenizer(input, { return_tensor: false}) as { input_ids: number[] }

const generationConfig = {
maxNewTokens: newTokens,
temperature: 1.0,
doSample: false,
topK: null
doSample: false
}
const predictedTokens = await this.model.generate(tokens, generationConfig)
const generatedWords = tokenizer.decode(predictedTokens[0])
@@ -118,6 +116,17 @@
config: this.config
}
}

[Symbol.dispose](): void{
console.log("Disposing model")
if (this.model.optimizer !== undefined) {
this.model.optimizer.dispose()
}
// Some tensors are not cleaned up when model.dispose is called
// So we dispose them manually
this.model.disposeRefs()
this.model.dispose()
}
}

export type GPTSerialization = {
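
The `[Symbol.dispose]` method added above lets callers free the model's optimizer state and leftover tensors explicitly between benchmark runs. A minimal usage sketch, assuming the `GPTConfig` fields shown in this diff and a constructor that accepts such a config, as in `cli/src/benchmark_gpt.ts`:

```ts
import { models } from '@epfml/discojs-core';

// Illustrative cleanup between benchmark runs: build a small GPT, use it,
// then release its optimizer state and cached tensors explicitly.
const config: models.GPTConfig = { modelType: 'gpt-nano', lr: 0.001 };
const model = new models.GPT(config);
// ... train or generate here ...
model[Symbol.dispose](); // disposes the optimizer and manually tracked tensors
```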
26 changes: 18 additions & 8 deletions discojs/discojs-core/src/models/gpt/layers.ts
@@ -59,13 +59,12 @@
private readonly dropout: number
private readonly bias: boolean
private readonly mask: tf.Tensor2D

cAttnKernel?: tf.LayerVariable
cAttnBias?: tf.LayerVariable
cProjKernel?: tf.LayerVariable
cProjBias?: tf.LayerVariable

constructor (private readonly config: CausalSelfAttentionConfig) {
constructor (private readonly config: CausalSelfAttentionConfig, disposalRefs: tf.TensorContainer[], private peakMemory: {value: number}) {
super(config)

this.nEmbd = config.nEmbd
@@ -77,6 +76,7 @@
// calling bandPart zero out the upper triangular part of the all-ones matrix
// from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part
this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0)
disposalRefs.push(this.mask) // Push a reference to dispose this matrix later
}

build (): void {
@@ -188,7 +188,10 @@
y = tf.reshape(y, [B, T, C])
y = dense(y, this.cProjKernel, this.cProjBias)
y = kwargs.training === true ? tf.dropout(y, this.dropout) : y

const memoryAllocated = tf.memory().numBytes / 1024 / 1024 / 1024 // GB
if (memoryAllocated > this.peakMemory.value) {
this.peakMemory.value = memoryAllocated
}
return y
})
}
@@ -257,7 +260,7 @@

type BlockConfig = CausalSelfAttentionConfig & MLPConfig & { debug: boolean }

function TransformerBlock (conf: BlockConfig): tf.LayersModel {
function TransformerBlock (conf: BlockConfig, disposalRefs: tf.TensorContainer[], peakMemory: {value: number}): tf.LayersModel {
const config = Object.assign({ name: 'h' }, conf)
const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] })
let x1, x2
@@ -269,7 +272,9 @@ }
}
// self attention layer
x1 = new CausalSelfAttention(
Object.assign({}, config, { name: config.name + '/attn' })
Object.assign({}, config, { name: config.name + '/attn' }),
disposalRefs,
peakMemory
).apply(x1)
// Residual connection
x1 = tf.layers.add().apply([inputs, x1 as tf.SymbolicTensor])
@@ -295,7 +300,10 @@
* @param conf GPTConfig
* @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
*/
export function GPTArchitecture (config: Required<GPTConfig>): tf.LayersModel {
export function GPTArchitecture(
config: Required<GPTConfig>,
disposalRefs: tf.TensorContainer[],
peakMemory: {value: number }): tf.LayersModel {
const inputs = tf.input({ shape: [null] })

//Token embedding
@@ -325,7 +333,7 @@ export function GPTArchitecture (config: Required<GPTConfig>): tf.LayersModel {

// token and positional embeddings are added together
let x = tf.layers.add().apply([tokEmb, posEmb])
//dropout
// dropout
x = tf.layers.dropout({name: 'drop', rate: config.embdDrop}).apply(x)
if (config.debug) {
x = new LogLayer({ name: 'dropadd' }).apply(x)
@@ -334,7 +342,9 @@
//Apply successively transformer blocks, attention and dense layers
for (let i = 0; i < config.nLayer; i++) {
x = TransformerBlock(
Object.assign({}, config, { name: config.name + '/h/' + i })
Object.assign({}, config, { name: config.name + '/h/' + i }),
disposalRefs,
peakMemory
).apply(x)
}
// Normalization
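
The peak-memory figure reported by the benchmark comes from polling `tf.memory().numBytes` inside the attention layer's forward pass, as the `layers.ts` changes above show. A standalone sketch of the same polling pattern, assuming `@tensorflow/tfjs` is available (the tensors here are only placeholders):

```ts
import * as tf from '@tensorflow/tfjs';

// Track the peak amount of memory TF.js has allocated, in GB, by sampling
// tf.memory().numBytes after the operations of interest.
const peakMemory = { value: 0 };

function recordPeak(): void {
  const allocatedGB = tf.memory().numBytes / 1024 / 1024 / 1024;
  if (allocatedGB > peakMemory.value) peakMemory.value = allocatedGB;
}

const a = tf.randomNormal([512, 512]);
const b = tf.matMul(a, a);
recordPeak();
console.log(`peak memory so far: ${peakMemory.value.toFixed(4)} GB`);
a.dispose();
b.dispose();
```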