Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save ETag metadata, add allowOffline #90

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/advanced/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@
logger: LoggerWithoutDebug,
});
// await wllama.cacheManager.clear();
console.log('Files in cache:', await wllama.cacheManager.list())

print(`Loading model ${MODEL}`);
timeStart();
await wllama.loadModelFromUrl(MODEL, {
n_ctx: 1024,
// allowOffline: true,
progressCallback: ({ loaded, total }) => console.log(`Downloading... ${Math.round(loaded/total*100)}%`),
});
print(`Loaded, take ${timeEnd()} ms`);
Expand Down
103 changes: 90 additions & 13 deletions src/cache-manager.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { isSafari, isSafariMobile } from './utils';

const PREFIX_METADATA = '__metadata__';

// To prevent breaking change, we fill etag with a pre-defined value
export const POLYFILL_ETAG = 'polyfill_for_older_version';

export interface CacheEntry {
/**
* File name in OPFS, in the format: `${hashSHA1(fullURL)}_${fileName}`
Expand All @@ -9,6 +14,26 @@ export interface CacheEntry {
* Size of file (in bytes)
*/
size: number;
/**
* Other metadata
*/
metadata: CacheEntryMetadata;
};

export interface CacheEntryMetadata {
/**
* ETag header from remote request
* https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
*/
etag: string;
/**
* Remote file size (in bytes), used for integrity check
*/
originalSize: number;
/**
* Original URL of the remote model. Unused for now
*/
originalURL: string;
};

/**
Expand All @@ -21,18 +46,19 @@ export const CacheManager = {
* Format of the file name: `${hashSHA1(fullURL)}_${fileName}`
*/
async getNameFromURL(url: string): Promise<string> {
return await toFileName(url);
return await toFileName(url, '');
},

/**
* Write a new file to cache. This will overwrite existing file.
*
* @param name The file name returned by `getNameFromURL()` or `list()`
*/
async write(name: string, stream: ReadableStream): Promise<void> {
async write(name: string, stream: ReadableStream, metadata: CacheEntryMetadata): Promise<void> {
CacheManager._writeMetadata(name, metadata); // no need await
return await opfsWrite(name, stream);
},

/**
* Open a file in cache for reading
*
Expand All @@ -44,7 +70,9 @@ export const CacheManager = {
},

/**
* Get the size of a file in cache
* Get the size of a file in stored cache
*
* NOTE: in case the download is stopped mid-way (i.e. user close browser tab), the file maybe corrupted, size maybe different from `metadata.originalSize`
*
* @param name The file name returned by `getNameFromURL()` or `list()`
* @returns number of bytes, or -1 if file does not exist
Expand All @@ -53,18 +81,59 @@ export const CacheManager = {
return await opfsFileSize(name);
},

/**
* Get metadata of a cached file
*/
async getMetadata(name: string): Promise<CacheEntryMetadata | null> {
const stream = await opfsOpen(name, PREFIX_METADATA);
const cachedSize = await CacheManager.getSize(name);
if (!stream) {
return cachedSize > 0
// files created by older version of wllama doesn't have metadata, we will try to polyfill it
? {
etag: POLYFILL_ETAG,
originalSize: cachedSize,
originalURL: '',
}
// if cached file not found, we don't have metadata at all
: null;
}
try {
const meta = await new Response(stream).json();
return meta;
} catch (e) {
// worst case: metadata is somehow corrupted, we will re-download the model
return null;
}
},

/**
* List all files currently in cache
*/
async list(): Promise<CacheEntry[]> {
const cacheDir = await getCacheDir();
const result: CacheEntry[] = [];
const metadataMap: Record<string, CacheEntryMetadata> = {};
// @ts-ignore
for await (let [name, handler] of cacheDir.entries()) {
if (handler.kind === 'file' && name.startsWith(PREFIX_METADATA)) {
const stream = (await (handler as FileSystemFileHandle).getFile()).stream();
const meta = await new Response(stream).json().catch(_ => null);
metadataMap[name.replace(PREFIX_METADATA, '')] = meta;
}
}
// @ts-ignore
for await (let [name, handler] of cacheDir.entries()) {
if (handler.kind === 'file') {
if (handler.kind === 'file' && !name.startsWith(PREFIX_METADATA)) {
result.push({
name,
size: await (handler as FileSystemFileHandle).getFile().then(f => f.size),
metadata: metadataMap[name] || {
// try to polyfill for old versions
originalSize: (await (handler as FileSystemFileHandle).getFile()).size,
originalURL: '',
etag: '',
},
});
}
}
Expand Down Expand Up @@ -104,15 +173,23 @@ export const CacheManager = {
}
}
},

/**
* Internally used
*/
async _writeMetadata(name: string, metadata: CacheEntryMetadata): Promise<void> {
const blob = new Blob([JSON.stringify(metadata)], { type: 'text/plain' });
await opfsWrite(name, blob.stream(), PREFIX_METADATA);
},
};

/**
* Write to OPFS file from ReadableStream
*/
async function opfsWrite(key: string, stream: ReadableStream): Promise<void> {
async function opfsWrite(key: string, stream: ReadableStream, prefix = ''): Promise<void> {
try {
const cacheDir = await getCacheDir();
const fileName = await toFileName(key);
const fileName = await toFileName(key, prefix);
const writable = isSafari()
? await opfsWriteViaWorker(fileName)
: await cacheDir.getFileHandle(fileName, { create: true }).then(h => h.createWritable());
Expand All @@ -133,10 +210,10 @@ async function opfsWrite(key: string, stream: ReadableStream): Promise<void> {
* Opens a file in OPFS for reading
* @returns ReadableStream
*/
async function opfsOpen(key: string): Promise<ReadableStream | null> {
async function opfsOpen(key: string, prefix = ''): Promise<ReadableStream | null> {
try {
const cacheDir = await getCacheDir();
const fileName = await toFileName(key);
const fileName = await toFileName(key, prefix);
const fileHandler = await cacheDir.getFileHandle(fileName);
const file = await fileHandler.getFile();
return file.stream();
Expand All @@ -150,10 +227,10 @@ async function opfsOpen(key: string): Promise<ReadableStream | null> {
* Get file size of a file in OPFS
* @returns number of bytes, or -1 if file does not exist
*/
async function opfsFileSize(key: string): Promise<number> {
async function opfsFileSize(key: string, prefix = ''): Promise<number> {
try {
const cacheDir = await getCacheDir();
const fileName = await toFileName(key);
const fileName = await toFileName(key, prefix);
const fileHandler = await cacheDir.getFileHandle(fileName);
const file = await fileHandler.getFile();
return file.size;
Expand All @@ -163,11 +240,11 @@ async function opfsFileSize(key: string): Promise<number> {
}
}

async function toFileName(str: string) {
async function toFileName(str: string, prefix: string) {
const hashBuffer = await crypto.subtle.digest('SHA-1', new TextEncoder().encode(str));
const hashArray = Array.from(new Uint8Array(hashBuffer));
const hashHex = hashArray.map(b => b.toString(16).padStart(2, '0')).join('');
return `${hashHex}_${str.split('/').pop()}`;
return `${prefix}${hashHex}_${str.split('/').pop()}`;
}

async function getCacheDir() {
Expand Down
4 changes: 4 additions & 0 deletions src/downloader/multi-downloads.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ export class MultiDownloads {
private logger: any;
private useCache: boolean;
private totalBytes: number = 0;
private allowOffline: boolean;

constructor(logger: any, urls: string[], maxParallel: number, opts: {
progressCallback?: ProgressCallback,
useCache: boolean,
allowOffline: boolean,
}) {
this.tasks = urls.map(url => {
// @ts-ignore
Expand All @@ -40,6 +42,7 @@ export class MultiDownloads {
this.maxParallel = maxParallel;
this.progressCallback = opts.progressCallback;
this.useCache = opts.useCache;
this.allowOffline = opts.allowOffline;
}

async run(): Promise<Blob[]> {
Expand All @@ -49,6 +52,7 @@ export class MultiDownloads {
logger: this.logger,
useCache: this.useCache,
startSignal: task.signalStart,
allowOffline: this.allowOffline,
progressCallback: ({ loaded }) => {
task.loaded = loaded;
this.updateProgress(task);
Expand Down
Loading