Skip to content

Commit

Permalink
add downloadModel function (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson authored Jul 31, 2024
1 parent d15748b commit 2dc146d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/downloader/multi-downloads.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ export class MultiDownloads {
private useCache: boolean;
private totalBytes: number = 0;
private allowOffline: boolean;
private noTEE: boolean;

constructor(logger: any, urls: string[], maxParallel: number, opts: {
progressCallback?: ProgressCallback,
useCache: boolean,
allowOffline: boolean,
noTEE?: boolean,
}) {
this.tasks = urls.map(url => {
// @ts-ignore
Expand All @@ -43,6 +45,7 @@ export class MultiDownloads {
this.progressCallback = opts.progressCallback;
this.useCache = opts.useCache;
this.allowOffline = opts.allowOffline;
this.noTEE = !!opts.noTEE;
}

async run(): Promise<Blob[]> {
Expand All @@ -53,6 +56,7 @@ export class MultiDownloads {
useCache: this.useCache,
startSignal: task.signalStart,
allowOffline: this.allowOffline,
noTEE: this.noTEE,
progressCallback: ({ loaded }) => {
task.loaded = loaded;
this.updateProgress(task);
Expand Down
15 changes: 14 additions & 1 deletion src/downloader/remote-blob.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ interface GGUFRemoteBlobCreateOptions {
progressCallback?: ProgressCallback;
startSignal?: Promise<void>;
allowOffline: boolean;
/**
* Should we skip TEE the output stream?
* Set to true if we only download model to cache, without reading it
*/
noTEE: boolean;
/**
* Custom debug logger
*/
Expand Down Expand Up @@ -80,6 +85,7 @@ export class GGUFRemoteBlob extends Blob {
cachedStream: cachedFile!,
progressCallback: () => {}, // unused
etag: remoteFile.etag,
noTEE: opts.noTEE,
});
} else {
if (remoteFile.originalSize !== cachedFileSize) {
Expand All @@ -93,6 +99,7 @@ export class GGUFRemoteBlob extends Blob {
progressCallback: opts?.progressCallback ?? (() => {}),
startSignal: opts?.startSignal,
etag: remoteFile.etag,
noTEE: opts.noTEE,
});
}
}
Expand All @@ -107,12 +114,14 @@ export class GGUFRemoteBlob extends Blob {
private cachedStream?: ReadableStream;
private progressCallback: ProgressCallback;
private startSignal?: Promise<void>;
private noTEE: boolean;

constructor(url: string, start: number, end: number, full: boolean, customFetch: typeof fetch, additionals: {
cachedStream?: ReadableStream,
progressCallback: ProgressCallback,
startSignal?: Promise<void>,
etag: string,
noTEE: boolean,
}) {
super([]);

Expand All @@ -130,6 +139,7 @@ export class GGUFRemoteBlob extends Blob {
this.progressCallback = additionals.progressCallback;
this.startSignal = additionals.startSignal;
this.etag = additionals.etag;
this.noTEE = additionals.noTEE;
}

override get size(): number {
Expand Down Expand Up @@ -161,7 +171,10 @@ export class GGUFRemoteBlob extends Blob {
let loaded = 0;
const stream = new TransformStream({
transform(chunk, controller) {
controller.enqueue(chunk);
// if noTEE is set, we discard the chunk
if (!self.noTEE) {
controller.enqueue(chunk);
}
loaded += chunk.byteLength;
self.progressCallback({
loaded,
Expand Down
28 changes: 28 additions & 0 deletions src/wllama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,36 @@ export class Wllama {
return paddedShardIds.map((current) => `${baseURL}-${current}-of-${total}.gguf`);
}

/**
* Download a model to cache, without loading it
* @param modelUrl URL or list of URLs (in the correct order)
* @param config
*/
async downloadModel(modelUrl: string | string[], config: DownloadModelConfig = {}): Promise<void> {
if (modelUrl.length === 0) {
throw new Error('modelUrl must be an URL or a list of URLs (in the correct order)');
}
if (config.useCache === false) {
throw new Error('useCache must not be false');
}
const multiDownloads = new MultiDownloads(
this.logger(),
this.parseModelUrl(modelUrl),
config.parallelDownloads ?? 3,
{
progressCallback: config.progressCallback,
useCache: true,
allowOffline: !!config.allowOffline,
noTEE: true,
}
);
await multiDownloads.run();
}

/**
* Load model from a given URL (or a list of URLs, in case the model is splitted into smaller files)
* - If the model already been downloaded (via `downloadModel()`), then we will use the cached model
* - Else, we download the model from internet
* @param modelUrl URL or list of URLs (in the correct order)
* @param config
*/
Expand Down

0 comments on commit 2dc146d

Please sign in to comment.