Improve cache API (#80)
* improve cache API

* add check guard for functions

* update docs
ngxson authored Jul 3, 2024
1 parent dfcb986 commit 6d0d208
Showing 2 changed files with 82 additions and 12 deletions.
52 changes: 45 additions & 7 deletions src/cache-manager.ts
@@ -15,27 +15,42 @@ export interface CacheEntry {
* Cache implementation using OPFS (Origin private file system)
*/
export const CacheManager = {
/**
* Convert a given URL into file name in cache.
*
* Format of the file name: `${hashSHA1(fullURL)}_${fileName}`
*/
async getNameFromURL(url: string): Promise<string> {
return await toFileName(url);
},

/**
* Write a new file to cache. This will overwrite existing file.
*
* @param name The file name returned by `getNameFromURL()` or `list()`
*/
async write(key: string, stream: ReadableStream): Promise<void> {
return await opfsWrite(key, stream);
async write(name: string, stream: ReadableStream): Promise<void> {
return await opfsWrite(name, stream);
},

/**
* Open a file in cache for reading
*
* @param name The file name returned by `getNameFromURL()` or `list()`
* @returns ReadableStream, or null if file does not exist
*/
async open(key: string): Promise<ReadableStream | null> {
return await opfsOpen(key);
async open(name: string): Promise<ReadableStream | null> {
return await opfsOpen(name);
},

/**
* Get the size of a file in cache
*
* @param name The file name returned by `getNameFromURL()` or `list()`
* @returns number of bytes, or -1 if file does not exist
*/
async getSize(key: string): Promise<number> {
return await opfsFileSize(key);
async getSize(name: string): Promise<number> {
return await opfsFileSize(name);
},

/**
@@ -60,10 +75,33 @@ export const CacheManager = {
* Clear all files currently in cache
*/
async clear(): Promise<void> {
await CacheManager.deleteMany(() => true);
},

/**
* Delete a single file in cache
*
* @param nameOrURL Can be either a URL or a name returned by `getNameFromURL()` or `list()`
*/
async delete(nameOrURL: string): Promise<void> {
const name2 = await CacheManager.getNameFromURL(nameOrURL);
await CacheManager.deleteMany((entry) => (
entry.name === nameOrURL || entry.name === name2
));
},

/**
* Delete multiple files in cache.
*
* @param predicate A predicate like `array.filter(item => boolean)`
*/
async deleteMany(predicate: (e: CacheEntry) => boolean): Promise<void> {
const cacheDir = await getCacheDir();
const list = await CacheManager.list();
for (const item of list) {
cacheDir.removeEntry(item.name);
if (predicate(item)) {
cacheDir.removeEntry(item.name);
}
}
},
};
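
The renamed `name` parameter and the new `delete()` / `deleteMany()` helpers make the cache usable end to end from the file names returned by `getNameFromURL()` or `list()`. A minimal usage sketch in TypeScript follows; the import path, the fetch-and-cache flow, and the `.gguf` filter are illustrative assumptions, not part of this commit:

// Sketch only: import path and download flow are assumptions.
import { CacheManager } from './cache-manager';

async function openCached(url: string): Promise<ReadableStream | null> {
  const name = await CacheManager.getNameFromURL(url);
  // getSize() returns -1 when the file is not yet in the cache
  if ((await CacheManager.getSize(name)) === -1) {
    const res = await fetch(url);
    if (!res.body) return null;
    await CacheManager.write(name, res.body);
  }
  return await CacheManager.open(name);
}

async function evictNonModels(): Promise<void> {
  // deleteMany() takes an array.filter-style predicate over CacheEntry;
  // filtering on the file extension here is just an example
  await CacheManager.deleteMany((entry) => !entry.name.endsWith('.gguf'));
  // delete() accepts either the original URL or a cached file name
  await CacheManager.delete('https://example.com/old-model.gguf');
}

Because `delete()` resolves its argument through `getNameFromURL()` and also matches it verbatim against entry names, callers can pass whichever form they have on hand.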
42 changes: 37 additions & 5 deletions src/wllama.ts
@@ -155,6 +155,19 @@ export class Wllama {
return this.config.logger ?? console;
}

private checkModelLoaded() {
if (!this.isModelLoaded()) {
throw new Error('loadModel() is not yet called');
}
}

/**
* Check if the model is loaded via `loadModel()`
*/
isModelLoaded(): boolean {
return !!this.proxy && !!this.metadata;
}

/**
* Get token ID associated to BOS (begin of sentence) token.
*
@@ -196,10 +209,8 @@
* @returns ModelMetadata
*/
getModelMetadata(): ModelMetadata {
if (!this.metadata) {
throw new Error('loadModel() is not yet called');
}
return this.metadata;
this.checkModelLoaded();
return this.metadata!;
}

/**
@@ -210,6 +221,7 @@
* @returns true if multi-thread is used.
*/
isMultithread(): boolean {
this.checkModelLoaded();
return this.useMultiThread;
}

@@ -366,6 +378,7 @@
skipBOS?: boolean,
skipEOS?: boolean,
} = {}): Promise<number[]> {
this.checkModelLoaded();
const opt = {
skipBOS: false,
skipEOS: false,
@@ -391,6 +404,7 @@
* @returns Output completion text (only the completion part)
*/
async createCompletion(prompt: string, options: ChatCompletionOptions): Promise<string> {
this.checkModelLoaded();
this.samplingConfig = options.sampling ?? {};
await this.samplingInit(this.samplingConfig);
await this.kvClear(); // TODO: maybe cache tokens?
@@ -438,6 +452,7 @@
* @param pastTokens In case re-initializing the ctx_sampling, you can re-import past tokens into the new context
*/
async samplingInit(config: SamplingConfig, pastTokens: number[] = []): Promise<void> {
this.checkModelLoaded();
this.samplingConfig = config;
const result = await this.proxy.wllamaAction('sampling_init', {
...config,
@@ -454,6 +469,7 @@
* @returns A list of Uint8Array. The nth element in the list associated to nth token in vocab
*/
async getVocab(): Promise<Uint8Array[]> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('get_vocab', {});
return result.vocab.map((arr: number[]) => new Uint8Array(arr));
}
@@ -465,6 +481,7 @@
* @returns Token ID associated to the given piece. Returns -1 if cannot find the token.
*/
async lookupToken(piece: string): Promise<number> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('lookup_token', { piece });
if (!result.success) {
return -1;
@@ -480,6 +497,7 @@
* @returns List of token ID
*/
async tokenize(text: string, special: boolean = true): Promise<number[]> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('tokenize', special
? { text, special: true }
: { text }
@@ -493,6 +511,7 @@
* @returns Uint8Array, which may be an unfinished unicode
*/
async detokenize(tokens: number[]): Promise<Uint8Array> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('detokenize', { tokens });
return new Uint8Array(result.buffer);
}
@@ -506,6 +525,7 @@
async decode(tokens: number[], options: {
skipLogits?: boolean,
}): Promise<{ nPast: number }> {
this.checkModelLoaded();
if (this.useEmbeddings) {
throw new Error('embeddings is enabled. Use wllama.setOptions({ embeddings: false }) to disable it.');
}
@@ -528,6 +548,7 @@
* @returns the token ID and its detokenized value (which may be an unfinished unicode)
*/
async samplingSample(): Promise<{ piece: Uint8Array, token: number }> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('sampling_sample', {});
return {
piece: new Uint8Array(result.piece),
@@ -540,6 +561,7 @@
* @param tokens
*/
async samplingAccept(tokens: number[]): Promise<void> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('sampling_accept', { tokens });
if (!result.success) {
throw new Error('samplingAccept unknown error');
@@ -551,6 +573,7 @@
* @param topK Get top K tokens having highest logits value. If topK == -1, we return all n_vocab logits, but this is not recommended because it's slow.
*/
async getLogits(topK: number = 40): Promise<{token: number, p: number}[]> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('get_logits', { top_k: topK });
const logits = result.logits as number[][];
return logits.map(([token, p]) => ({ token, p }));
@@ -562,6 +585,7 @@
* @returns A list of number represents an embedding vector of N dimensions
*/
async embeddings(tokens: number[]): Promise<number[]> {
this.checkModelLoaded();
if (!this.useEmbeddings) {
throw new Error('embeddings is disabled. Use wllama.setOptions({ embeddings: true }) to enable it.');
}
@@ -582,6 +606,7 @@
* @param nDiscard
*/
async kvRemove(nKeep: number, nDiscard: number): Promise<void> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('kv_remove', {
n_keep: nKeep,
n_discard: nDiscard,
@@ -595,6 +620,7 @@
* Clear all tokens in KV cache
*/
async kvClear(): Promise<void> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('kv_clear', {});
if (!result.success) {
throw new Error('kvClear unknown error');
@@ -608,6 +634,7 @@
* @returns List of tokens saved to the file
*/
async sessionSave(filePath: string): Promise<{ tokens: number[] }> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('session_save', { session_path: filePath });
return result;
}
@@ -619,6 +646,7 @@
*
*/
async sessionLoad(filePath: string): Promise<void> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('session_load', { session_path: filePath });
if (result.error) {
throw new Error(result.error);
Expand All @@ -631,12 +659,15 @@ export class Wllama {
* Set options for underlying llama_context
*/
async setOptions(opt: ContextOptions): Promise<void> {
this.checkModelLoaded();
await this.proxy.wllamaAction('set_options', opt);
this.useEmbeddings = opt.embeddings;
}

/**
* Unload the model and free all memory
* Unload the model and free all memory.
*
* Note: This function will NOT crash if model is not yet loaded
*/
async exit(): Promise<void> {
await this.proxy?.wllamaExit();
@@ -646,6 +677,7 @@
* get debug info
*/
async _getDebugInfo(): Promise<any> {
this.checkModelLoaded();
return await this.proxy.wllamaDebug();
}

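
With `checkModelLoaded()` in place, every method that needs a loaded model now fails fast with the same error instead of hitting an undefined `this.proxy` or `this.metadata`. A rough caller-side sketch of the new behavior; the import path and the constructor argument shape are assumptions, while `isModelLoaded()` and the error message come from this diff:

// Sketch only: construction details are assumptions.
import { Wllama } from './wllama';

async function demoGuard(pathConfig: ConstructorParameters<typeof Wllama>[0]): Promise<void> {
  const wllama = new Wllama(pathConfig);
  console.log(wllama.isModelLoaded()); // false before loadModel() succeeds

  try {
    await wllama.tokenize('hello'); // guarded methods now throw early...
  } catch (e) {
    console.error(e); // ...with: Error: loadModel() is not yet called
  }

  // After a successful loadModel(...) call (signature not shown here),
  // isModelLoaded() returns true and the guarded methods work normally.

  // exit() is deliberately left unguarded, so cleanup never throws
  // even when no model was loaded:
  await wllama.exit();
}

Note that `exit()` stays unguarded, which matches the updated doc comment: unloading is safe even if `loadModel()` was never called.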
