Benchmark GPT-tfjs #659

Merged: 17 commits, May 1, 2024
Changes from 9 commits
26 changes: 11 additions & 15 deletions discojs/discojs-core/src/models/gpt/config.ts
@@ -1,17 +1,17 @@
-type ModelType =
-| 'gpt2'
-| 'gpt2-medium'
-| 'gpt2-large'
-| 'gpt2-xl'
-| 'gpt-mini'
-| 'gpt-micro'
-| 'gpt-nano'
+export type GPTModelType =
+| 'gpt2'
+| 'gpt2-medium'
+| 'gpt2-large'
+| 'gpt2-xl'
+| 'gpt-mini'
+| 'gpt-micro'
+| 'gpt-nano'

export interface GPTConfig {
lr: number
blockSize: number
vocabSize: number
-modelType: ModelType
+modelType: GPTModelType
name?: string,
evaluate?: boolean
maxEvalBatches?: number
@@ -30,7 +30,7 @@ export interface GPTConfig {
nHead?: number
nEmbd?: number
}
-
+// for a benchmark of performance, see https://github.com/epfml/disco/pull/659
export const DEFAULT_CONFIG: Required<GPTConfig> = {
name: 'transformer',
lr: 0.001,
@@ -61,7 +61,7 @@ export type ModelSize = {
nEmbd: number
}

-export function getModelSizes (modelType: ModelType): Required<ModelSize> {
+export function getModelSizes (modelType: GPTModelType): Required<ModelSize> {
switch (modelType) {
case 'gpt2':
return { nLayer: 12, nHead: 12, nEmbd: 768 }
@@ -77,9 +77,5 @@ export function getModelSizes (modelType: ModelType): Required<ModelSize> {
return { nLayer: 4, nHead: 4, nEmbd: 128 }
case 'gpt-nano':
return { nLayer: 3, nHead: 3, nEmbd: 48 }
-default: {
-const _: never = modelType
-throw new Error("should never happen")
-}
}
}
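For orientation, here is a minimal sketch of how these exports are meant to be combined; the import path `./config` is an assumption based on the file layout, and the exports (`DEFAULT_CONFIG`, `getModelSizes`, `GPTModelType`, `GPTConfig`) are taken from this diff. Note that with the `default` branch gone, exhaustiveness rests on `GPTModelType` being a closed union rather than on the `never` check.

```ts
import { DEFAULT_CONFIG, getModelSizes, type GPTModelType, type GPTConfig } from './config'

// Resolve the architecture hyperparameters for a preset...
const modelType: GPTModelType = 'gpt-nano'
const { nLayer, nHead, nEmbd } = getModelSizes(modelType) // { nLayer: 3, nHead: 3, nEmbd: 48 }

// ...and overlay them on the defaults to obtain a fully specified config.
const config: Required<GPTConfig> = { ...DEFAULT_CONFIG, modelType, nLayer, nHead, nEmbd }
```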
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/models/gpt/gpt.spec.ts
@@ -37,7 +37,7 @@ describe('gpt-tfjs', function() {
const model = new GPT(config)
const logGenerator = model.train(tokenDataset, undefined, 5) // 5 epochs
for await (const _ of logGenerator); // Await the end of training
-const generation = await model.generate("Lorem ipsum dolor", tokenizer, 1)
+const { generation, avgTokenTime: _ } = await model.generate("Lorem ipsum dolor", tokenizer, 1)
console.log(generation)
expect(generation).equal(data) // Assert that the model completes 'Lorem ipsum dolor' with 'sit'
})
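Call sites now destructure the new return shape. A hypothetical usage, assuming a trained `model` and a `tokenizer` in scope (milliseconds per token is an assumption, inferred from the `timePerToken` callback in `index.ts` below):

```ts
const { generation, avgTokenTime } = await model.generate('Lorem ipsum dolor', tokenizer, 1)
console.log(`generated "${generation}" at ~${avgTokenTime.toFixed(1)} ms/token`)
```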
35 changes: 27 additions & 8 deletions discojs/discojs-core/src/models/gpt/index.ts
@@ -45,13 +45,12 @@ export class GPT extends Model {
};
for (let epoch = 0; epoch < epochs; epoch++) {
await this.model.fitDataset(trainingData, trainingArgs);
-
if (logs === undefined) {
throw new Error("epoch didn't give any logs");
}
-const { loss, val_acc, val_loss } = logs;
+const { loss, val_acc, val_loss, weightUpdateTime, peakMemory } = logs;
if (loss === undefined || isNaN(loss)) {
-throw new Error("Invalid training logs");
+throw new Error("Training loss is undefined or nan");
}
const structuredLogs: EpochLogs = {
epoch,
@@ -67,6 +66,12 @@
}
structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss}
}
+if (weightUpdateTime !== undefined && !isNaN(weightUpdateTime)) {
+structuredLogs['weightUpdateTime'] = weightUpdateTime
+}
+if (peakMemory !== undefined && !isNaN(peakMemory)) {
+structuredLogs['peakMemory'] = peakMemory
+}

yield structuredLogs
}
@@ -81,18 +86,24 @@
return Promise.resolve(ret)
}

-async generate (input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {
+async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10):
+Promise<{ generation: string, avgTokenTime: number }> {
const { input_ids: tokens } = await tokenizer(input, { return_tensor: false}) as { input_ids: number[] }

const generationConfig = {
maxNewTokens: newTokens,
temperature: 1.0,
-doSample: false,
-topK: null
+doSample: false
}
-const predictedTokens = await this.model.generate(tokens, generationConfig)
+let avgTimePerToken = 0
+const predictedTokens = await this.model.generate(tokens, generationConfig, (res) => {
+avgTimePerToken += res.timePerToken
+})
const generatedWords = tokenizer.decode(predictedTokens[0])
-return generatedWords
+return {
+generation: generatedWords,
+avgTokenTime: avgTimePerToken / generationConfig.maxNewTokens
+}
}

get config (): Required<GPTConfig> {
@@ -118,6 +129,14 @@ export class GPT extends Model {
config: this.config
}
}
+
+dispose(): void {
+this.model.optimizer.dispose()
+// Some tensors are not cleaned up when model.dispose is called
+// So we dispose them manually
+this.model.disposeRefs()
+this.model.dispose()
+}
}

export type GPTSerialization = {
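Putting the pieces together, a benchmarking loop over these changes could look like the sketch below. It assumes a `config` and a prepared `tokenDataset` as in `gpt.spec.ts`, and that the optional `weightUpdateTime` (assumed to be ms) and `peakMemory` (GB, per the comment in `layers.ts`) fields surface on the yielded logs as added above:

```ts
const model = new GPT(config)
try {
  // train() yields one EpochLogs object per epoch (5 epochs here).
  for await (const logs of model.train(tokenDataset, undefined, 5)) {
    console.log(
      `epoch ${logs.epoch}:`,
      `weight update ${logs.weightUpdateTime ?? 'n/a'} ms,`,
      `peak memory ${logs.peakMemory ?? 'n/a'} GB`
    )
  }
} finally {
  model.dispose() // frees the optimizer and the tensors tracked via disposalRefs
}
```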
26 changes: 18 additions & 8 deletions discojs/discojs-core/src/models/gpt/layers.ts
@@ -59,13 +59,12 @@ class CausalSelfAttention extends tf.layers.Layer {
private readonly dropout: number
private readonly bias: boolean
private readonly mask: tf.Tensor2D
-
cAttnKernel?: tf.LayerVariable
cAttnBias?: tf.LayerVariable
cProjKernel?: tf.LayerVariable
cProjBias?: tf.LayerVariable

-constructor (private readonly config: CausalSelfAttentionConfig) {
+constructor (private readonly config: CausalSelfAttentionConfig, disposalRefs: tf.TensorContainer[], private peakMemory: {value: number}) {
super(config)

this.nEmbd = config.nEmbd
@@ -77,6 +76,7 @@
// calling bandPart zeroes out the upper triangular part of the all-ones matrix
// from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part
this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0)
+disposalRefs.push(this.mask) // Push a reference to dispose this matrix later
}

build (): void {
@@ -188,7 +188,10 @@
y = tf.reshape(y, [B, T, C])
y = dense(y, this.cProjKernel, this.cProjBias)
y = kwargs.training === true ? tf.dropout(y, this.dropout) : y
-
+const memoryAllocated = tf.memory().numBytes / 1024 / 1024 / 1024 // GB
+if (memoryAllocated > this.peakMemory.value) {
+this.peakMemory.value = memoryAllocated
+}
return y
})
}
@@ -257,7 +260,7 @@ function MLP (config: MLPConfig): tf.LayersModel {

type BlockConfig = CausalSelfAttentionConfig & MLPConfig & { debug: boolean }

-function TransformerBlock (conf: BlockConfig): tf.LayersModel {
+function TransformerBlock (conf: BlockConfig, disposalRefs: tf.TensorContainer[], peakMemory: {value: number}): tf.LayersModel {
const config = Object.assign({ name: 'h' }, conf)
const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] })
let x1, x2
@@ -269,7 +272,9 @@ function TransformerBlock (conf: BlockConfig): tf.LayersModel {
}
// self attention layer
x1 = new CausalSelfAttention(
-Object.assign({}, config, { name: config.name + '/attn' })
+Object.assign({}, config, { name: config.name + '/attn' }),
+disposalRefs,
+peakMemory
).apply(x1)
// Residual connection
x1 = tf.layers.add().apply([inputs, x1 as tf.SymbolicTensor])
@@ -295,7 +300,10 @@
* @param conf GPTConfig
* @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
*/
-export function GPTArchitecture (config: Required<GPTConfig>): tf.LayersModel {
+export function GPTArchitecture(
+config: Required<GPTConfig>,
+disposalRefs: tf.TensorContainer[],
+peakMemory: {value: number }): tf.LayersModel {
const inputs = tf.input({ shape: [null] })

//Token embedding
@@ -325,7 +333,7 @@ export function GPTArchitecture (config: Required<GPTConfig>): tf.LayersModel {

// token and positional embeddings are added together
let x = tf.layers.add().apply([tokEmb, posEmb])
-//dropout
+// dropout
x = tf.layers.dropout({name: 'drop', rate: config.embdDrop}).apply(x)
if (config.debug) {
x = new LogLayer({ name: 'dropadd' }).apply(x)
@@ -334,7 +342,9 @@ export function GPTArchitecture (config: Required<GPTConfig>): tf.LayersModel {
//Apply successively transformer blocks, attention and dense layers
for (let i = 0; i < config.nLayer; i++) {
x = TransformerBlock(
-Object.assign({}, config, { name: config.name + '/h/' + i })
+Object.assign({}, config, { name: config.name + '/h/' + i }),
+disposalRefs,
+peakMemory
).apply(x)
}
// Normalization
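Finally, a standalone sketch of the two side channels threaded through `layers.ts`: `disposalRefs` collects tensors (such as the causal mask) that `model.dispose()` would otherwise leak, and the mutable `peakMemory` box is written inside the attention layer's forward pass. The names follow the diff; the call site itself is an assumption.

```ts
import * as tf from '@tensorflow/tfjs'

const disposalRefs: tf.TensorContainer[] = [] // tensors to free manually on teardown
const peakMemory = { value: 0 }               // GB, updated during forward passes

const model = GPTArchitecture(config, disposalRefs, peakMemory)
// ... fit / predict as usual ...
console.log(`peak allocation observed: ${peakMemory.value.toFixed(2)} GB`)

// model.dispose() does not reach tensors created outside layer weights,
// so release the collected references explicitly.
tf.dispose(disposalRefs)
model.dispose()
```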