Skip to content

Commit

Permalink
fix: correct special token matching & counting
Browse files Browse the repository at this point in the history
  • Loading branch information
niieani committed Nov 13, 2024
1 parent 6030d91 commit 3547826
Show file tree
Hide file tree
Showing 43 changed files with 131 additions and 28 deletions.
37 changes: 17 additions & 20 deletions src/BytePairEncodingCore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,18 @@ export class BytePairEncodingCore {
let lastTokenLength = 0

while (true) {
const nextSpecialStartIndex = this.findNextSpecialStartIndex(
const nextSpecialMatch = this.findNextSpecialToken(
text,
allowedSpecial,
startIndex,
)
const nextSpecialStartIndex = nextSpecialMatch?.[0]

const endIndex =
nextSpecialStartIndex !== undefined
? nextSpecialStartIndex
: text.length
const endIndex = nextSpecialStartIndex ?? text.length

const textSegment = text.slice(startIndex, endIndex - startIndex)
const textBeforeSpecial = text.slice(startIndex, endIndex)

for (const [match] of textSegment.matchAll(this.tokenSplitRegex)) {
for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
const token = this.getBpeRankFromString(match)
if (token !== undefined) {
lastTokenLength = 1
Expand All @@ -115,7 +113,7 @@ export class BytePairEncodingCore {
}

if (nextSpecialStartIndex !== undefined) {
const specialToken = text.slice(Math.max(0, nextSpecialStartIndex))
const specialToken = nextSpecialMatch![1]
const specialTokenValue = this.specialTokensEncoder.get(specialToken)
if (specialTokenValue === undefined) {
throw new Error(
Expand All @@ -124,7 +122,7 @@ export class BytePairEncodingCore {
}
yield [specialTokenValue]
startIndex = nextSpecialStartIndex + specialToken.length
lastTokenLength = 0
lastTokenLength = 1
} else {
break
}
Expand All @@ -139,20 +137,18 @@ export class BytePairEncodingCore {

// eslint-disable-next-line no-constant-condition
while (true) {
const nextSpecialStartIndex = this.findNextSpecialStartIndex(
const nextSpecialMatch = this.findNextSpecialToken(
text,
allowedSpecial,
startIndex,
)
const nextSpecialStartIndex = nextSpecialMatch?.[0]

const endIndex =
nextSpecialStartIndex !== undefined
? nextSpecialStartIndex
: text.length
const endIndex = nextSpecialStartIndex ?? text.length

const textSegment = text.slice(startIndex, endIndex - startIndex)
const textBeforeSpecial = text.slice(startIndex, endIndex)

for (const [match] of textSegment.matchAll(this.tokenSplitRegex)) {
for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
const token = this.getBpeRankFromString(match)
if (token !== undefined) {
tokensArray.push(token)
Expand All @@ -165,7 +161,7 @@ export class BytePairEncodingCore {
}

if (nextSpecialStartIndex !== undefined) {
const specialToken = text.slice(Math.max(0, nextSpecialStartIndex))
const specialToken = nextSpecialMatch![1]
const specialTokenValue = this.specialTokensEncoder.get(specialToken)
if (specialTokenValue === undefined) {
throw new Error(
Expand Down Expand Up @@ -303,11 +299,11 @@ export class BytePairEncodingCore {
return -1
}

private findNextSpecialStartIndex(
private findNextSpecialToken(
text: string,
allowedSpecial: Set<string> | undefined,
startIndex: number,
): number | undefined {
): [startIndex: number, token: string] | undefined {
let searchIndex = startIndex

// eslint-disable-next-line no-constant-condition
Expand All @@ -323,7 +319,8 @@ export class BytePairEncodingCore {
const specialToken = nextSpecialMatch[0]

if (allowedSpecial?.has(specialToken)) {
return nextSpecialMatch.index + searchIndex
const specialTokenStartIndex = nextSpecialMatch.index + searchIndex
return [specialTokenStartIndex, specialToken]
}

searchIndex = nextSpecialMatch.index + searchIndex + 1
Expand Down
37 changes: 36 additions & 1 deletion src/GptEncoding.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import fs from 'fs'
import path from 'path'
import { ALL_SPECIAL_TOKENS } from './constants.js'
import { type ChatMessage, GptEncoding } from './GptEncoding.js'
import {
type ChatModelName,
Expand Down Expand Up @@ -61,6 +62,25 @@ const results = {
r50k_base: sharedResults,
} satisfies Record<EncodingName, unknown>

// Prompts for the 'encode and decode' round-trip tests: each string is encoded
// with all special tokens allowed and must decode back to the identical string.
// The samples cover plain ASCII, an embedded special token, and multi-byte
// UTF-8 text.
const offsetPrompts = [
// Basic prompt with "hello world"
'hello world',

// Basic prompt with special token "<|endoftext|>"
'hello world<|endoftext|> green cow',

// Chinese text: "我非常渴望与人工智能一起工作"
// (roughly: "I am very eager to work with artificial intelligence")
'我非常渴望与人工智能一起工作',

// Contains the interesting tokens b'\xe0\xae\xbf\xe0\xae' and b'\xe0\xaf\x8d\xe0\xae'
// in which \xe0 is the start of a 3-byte UTF-8 character
'நடிகர் சூர்யா',

// Contains the interesting token b'\xa0\xe9\x99\xa4'
// in which \xe9 is the start of a 3-byte UTF-8 character and \xa0 is a continuation byte
' Ġ除',
]

// eslint-disable-next-line @typescript-eslint/no-use-before-define
const testPlans = loadTestPlans()

Expand All @@ -75,9 +95,16 @@ describe.each(encodingNames)('%s', (encodingName: EncodingName) => {
isWithinTokenLimit,
} = encoding

const result = results[encodingName]
describe('encode and decode', () => {
it.each(offsetPrompts)('offset prompt: %s', (str) => {
expect(
decode(encode(str, { allowedSpecial: ALL_SPECIAL_TOKENS })),
).toEqual(str)
})
})

describe('basic functionality', () => {
const result = results[encodingName]
it('empty string', () => {
const str = ''
expect(encode(str)).toEqual([])
Expand Down Expand Up @@ -131,6 +158,14 @@ describe.each(encodingNames)('%s', (encodingName: EncodingName) => {
})
})

it('encodes and decodes special tokens', () => {
const str = 'hello <|endoftext|> world'
const encoded = encode(str, {
allowedSpecial: ALL_SPECIAL_TOKENS,
})
expect(decode(encoded)).toEqual(str)
})

async function* getHelloWorldTokensAsync() {
const str = 'hello 👋 world 🌍'
for (const token of result[str]) {
Expand Down
45 changes: 38 additions & 7 deletions src/GptEncoding.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/member-ordering */
/* eslint-disable no-param-reassign */
import { BytePairEncodingCore, decoder } from './BytePairEncodingCore.js'
import { ALL_SPECIAL_TOKENS } from './constants.js'
import {
type ChatModelName,
type ChatParameters,
Expand Down Expand Up @@ -29,11 +30,19 @@ import {
import { endsWithIncompleteUtfPairSurrogate } from './utfUtil.js'
import { getMaxValueFromMap, getSpecialTokenRegex } from './util.js'

export const ALL_SPECIAL_TOKENS = 'all'

export interface EncodeOptions {
allowedSpecial?: Set<string>
disallowedSpecial?: Set<string>
/**
* The set of special tokens that are allowed in the input.
* If set to 'all' (or to a set containing 'all'), every special token is allowed except those in disallowedSpecial.
* @default undefined
*/
allowedSpecial?: Set<string> | typeof ALL_SPECIAL_TOKENS
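/**
* The set of special tokens that are disallowed in the input.
* If set to 'all' (or to a set containing 'all'), every special token is disallowed except those in allowedSpecial.
* An error is thrown if this is 'all' while allowedSpecial is also 'all'.
* @default 'all'
*/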
disallowedSpecial?: Set<string> | typeof ALL_SPECIAL_TOKENS
}

export interface ChatMessage {
Expand Down Expand Up @@ -168,15 +177,37 @@ export class GptEncoding {
}: EncodeOptions = {}): SpecialTokenConfig {
let regexPattern: RegExp | undefined

if (allowedSpecial?.has(ALL_SPECIAL_TOKENS)) {
if (
allowedSpecial === ALL_SPECIAL_TOKENS ||
allowedSpecial?.has(ALL_SPECIAL_TOKENS)
) {
allowedSpecial = new Set(this.specialTokensSet)
const allowedSpecialSet = allowedSpecial
if (disallowedSpecial === ALL_SPECIAL_TOKENS) {
throw new Error(
'allowedSpecial and disallowedSpecial cannot both be set to "all".',
)
}
if (typeof disallowedSpecial === 'object') {
// remove any special tokens that are disallowed
disallowedSpecial.forEach((val) => allowedSpecialSet.delete(val))
} else {
// all special tokens are allowed, and no 'disallowedSpecial' is provided
disallowedSpecial = new Set()
}
}

if (!disallowedSpecial || disallowedSpecial.has(ALL_SPECIAL_TOKENS)) {
if (
!disallowedSpecial ||
disallowedSpecial === ALL_SPECIAL_TOKENS ||
disallowedSpecial.has(ALL_SPECIAL_TOKENS)
) {
// by default, all special tokens are disallowed
disallowedSpecial = new Set(this.specialTokensSet)
const disallowedSpecialSet = disallowedSpecial
if (allowedSpecial?.size) {
allowedSpecial.forEach((val) => disallowedSpecial!.delete(val))
allowedSpecial.forEach((val) => disallowedSpecialSet.delete(val))
// disallowed takes precedence over allowed
disallowedSpecial.forEach((val) => allowedSpecial.delete(val))
regexPattern = getSpecialTokenRegex(disallowedSpecial)
} else {
Expand Down
1 change: 1 addition & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/**
 * Sentinel value accepted by the `allowedSpecial` and `disallowedSpecial`
 * encode options to refer to every special token at once.
 */
export const ALL_SPECIAL_TOKENS = 'all'
1 change: 1 addition & 0 deletions src/encoding/cl100k_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'

const api = GptEncoding.getEncodingApi('cl100k_base', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/encoding/o200k_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/o200k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'

const api = GptEncoding.getEncodingApi('o200k_base', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/encoding/p50k_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/p50k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'

const api = GptEncoding.getEncodingApi('p50k_base', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/encoding/p50k_edit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/p50k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'

const api = GptEncoding.getEncodingApi('p50k_edit', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/encoding/r50k_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/r50k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'

const api = GptEncoding.getEncodingApi('r50k_base', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo-0125.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo-0125', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo-0301.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo-0301', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo-0613.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo-0613', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo-1106.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo-1106', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo-16k-0613.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo-16k-0613', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo-16k.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo-16k', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo-finetune.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo-finetune', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-3.5-turbo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-3.5-turbo', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-4-0125-preview.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-4-0125-preview', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-4-0314.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-4-0314', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-4-0613.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-4-0613', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-4-1106-preview.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-4-1106-preview', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-4-1106-vision-preview.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-4-1106-vision-preview', () => bpeRanks)
Expand Down
1 change: 1 addition & 0 deletions src/model/gpt-4-32k-0314.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import bpeRanks from '../bpeRanks/cl100k_base.js'
import { GptEncoding } from '../GptEncoding.js'

export * from '../constants.js'
export * from '../specialTokens.js'
// prettier-ignore
const api = GptEncoding.getEncodingApiForModel('gpt-4-32k-0314', () => bpeRanks)
Expand Down
Loading

0 comments on commit 3547826

Please sign in to comment.