Skip to content

Commit

Permalink
Stream weights to the GPU when loading a model (#7994)
Browse files Browse the repository at this point in the history
When downloading model weight data, slice it into weight tensors and push them to the GPU eagerly. This avoids storing an extra copy of the weights on CPU, allowing for larger models (1.3B to possibly ~6.7B or larger) to be loaded without causing a V8 OOM crash.

When streaming the weights, check CPU_HANDOFF_SIZE_THRESHOLD or WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD to determine whether the weight should be sent to GPU or remain on CPU.

This feature is guarded by the streamWeights option in LoadOptions. Since most of TFJS's graph model saving relies on the CPU copy of the model, model saving is disabled when the model was streamed (i.e. it will throw an error since the weights ArrayBuffer is missing).
  • Loading branch information
mattsoulanille authored Nov 28, 2023
1 parent 929b35d commit e2ba43c
Show file tree
Hide file tree
Showing 11 changed files with 532 additions and 284 deletions.
30 changes: 27 additions & 3 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import {OperationMapper} from '../operations/operation_mapper';

import {GraphExecutor} from './graph_executor';
import {ResourceManager} from './resource_manager';
// tslint:disable-next-line: no-imports-from-dist
import {decodeWeightsStream} from '@tensorflow/tfjs-core/dist/io/io_utils';

export const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
export const DEFAULT_MODEL_NAME = 'model.json';
Expand Down Expand Up @@ -154,7 +156,12 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements

const loadResult = this.handler.load() as ReturnType<IOHandler['load']>;
if (util.isPromise(loadResult)) {
return loadResult.then(artifacts => this.loadSync(artifacts)) as Result;
return loadResult.then(artifacts => {
if (artifacts.getWeightStream == null) {
return this.loadSync(artifacts);
}
return this.loadStreaming(artifacts);
}) as Result;
}

return this.loadSync(loadResult) as Result;
Expand All @@ -167,6 +174,25 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
loadSync(artifacts: io.ModelArtifacts) {
  // Decode the raw weight buffer into named tensors, then finish
  // initialization through the shared loading path.
  const decodedWeights = this.io.decodeWeights(
      artifacts.weightData, artifacts.weightSpecs);
  return this.loadWithWeightMap(artifacts, decodedWeights);
}

private async loadStreaming(artifacts: io.ModelArtifacts): Promise<boolean> {
  // Streaming requires the artifacts to expose a weight-stream factory.
  if (artifacts.getWeightStream == null) {
    throw new Error('Model artifacts missing streamWeights function');
  }

  // Decode tensors incrementally from the stream instead of from a single
  // in-memory ArrayBuffer, then finish through the shared loading path.
  return this.loadWithWeightMap(
      artifacts,
      await decodeWeightsStream(
          artifacts.getWeightStream(), artifacts.weightSpecs));
}

private loadWithWeightMap(artifacts: io.ModelArtifacts,
weightMap: NamedTensorMap) {
this.artifacts = artifacts;
const graph = this.artifacts.modelTopology as tensorflow.IGraphDef;

Expand All @@ -184,8 +210,6 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
this.signature = signature;

this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
const weightMap = this.io.decodeWeights(
this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(
OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
Expand Down
38 changes: 36 additions & 2 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import {GraphNode} from '../operations/types';
import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model';
import {HASH_TABLE_MODEL_V2} from './test_data/hash_table_v2_model_loader';
import {STRUCTURED_OUTPUTS_MODEL} from './test_data/structured_outputs_model_loader';
// tslint:disable-next-line: no-imports-from-dist
import {expectArrayBuffersEqual} from '@tensorflow/tfjs-core/dist/test_util';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;
Expand Down Expand Up @@ -125,6 +127,24 @@ const SIMPLE_HTTP_MODEL_LOADER = {
}
};

// Same artifacts as SIMPLE_HTTP_MODEL_LOADER, except the weights are
// delivered through getWeightStream (a ReadableStream) rather than as a
// buffered weightData payload.
const SIMPLE_STREAMING_MODEL_LOADER = {
  load: async () => {
    const getWeightStream = () => {
      // Wrap the bias tensor's data in a Blob to obtain a ReadableStream.
      return new Blob([bias.dataSync()]).stream();
    };
    return {
      modelTopology: SIMPLE_MODEL,
      weightSpecs: weightsManifest,
      getWeightStream,
      format: 'tfjs-graph-model',
      generatedBy: '1.15',
      convertedBy: '1.3.1',
      userDefinedMetadata: {signature: SIGNATURE}
    };
  }
};

const NO_INPUT_SIGNATURE_MODEL_LOADER = {
load: async () => {
return {
Expand Down Expand Up @@ -438,7 +458,7 @@ describe('loadGraphModel', () => {
});

it('Pass a fetchFunc', async () => {
const fetchFunc = () => {};
const fetchFunc = (() => {}) as unknown as typeof fetch;
spyIo.getLoadHandlers.and.returnValue([CUSTOM_HTTP_MODEL_LOADER]);
await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo);
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
Expand Down Expand Up @@ -594,7 +614,13 @@ describe('Model', () => {

describe('simple model', () => {
beforeEach(() => {
spyIo.getLoadHandlers.and.returnValue([SIMPLE_HTTP_MODEL_LOADER]);
spyIo.getLoadHandlers.and.callFake((_url: string|string[],
loadOptions?: io.LoadOptions) => {
if (loadOptions.streamWeights) {
return [SIMPLE_STREAMING_MODEL_LOADER];
}
return [SIMPLE_HTTP_MODEL_LOADER];
});
spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
});
it('load', async () => {
Expand Down Expand Up @@ -776,6 +802,14 @@ describe('Model', () => {
expect(model).toBeDefined();
});

it('should stream graph model weights', async () => {
  // streamWeights routes loading through SIMPLE_STREAMING_MODEL_LOADER
  // (see the beforeEach callFake above).
  const model =
      await loadGraphModel(MODEL_URL, {streamWeights: true}, spyIo);
  expect(model).toBeDefined();
  // The tensor reconstructed from the stream must match the source data.
  const streamed = model.weights['Const'][0].dataSync();
  expectArrayBuffersEqual(streamed, bias.dataSync());
});

describe('InferenceModel interface', () => {
it('should expose inputs', async () => {
await model.load();
Expand Down
68 changes: 45 additions & 23 deletions tfjs-core/src/io/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,30 @@ import {assert} from '../util';
import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
import {CompositeArrayBuffer} from './composite_array_buffer';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer} from './weights_loader';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer, streamWeights} from './weights_loader';

const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
const JSON_TYPE = 'application/json';
export class HTTPRequest implements IOHandler {
protected readonly path: string;
protected readonly requestInit: RequestInit;

private readonly fetch: Function;
private readonly fetch: typeof fetch;
private readonly weightUrlConverter: (weightName: string) => Promise<string>;

readonly DEFAULT_METHOD = 'POST';

static readonly URL_SCHEME_REGEX = /^https?:\/\//;

private readonly weightPathPrefix: string;
private readonly onProgress: OnProgressCallback;
private readonly loadOptions: LoadOptions;

constructor(path: string, loadOptions?: LoadOptions) {
if (loadOptions == null) {
loadOptions = {};
}
this.weightPathPrefix = loadOptions.weightPathPrefix;
this.onProgress = loadOptions.onProgress;
this.weightUrlConverter = loadOptions.weightUrlConverter;

if (loadOptions.fetchFunc != null) {
Expand Down Expand Up @@ -84,6 +83,7 @@ export class HTTPRequest implements IOHandler {
'requestInit is expected to have no pre-existing body, but has one.');
}
this.requestInit = loadOptions.requestInit || {};
this.loadOptions = loadOptions;
}

async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
Expand Down Expand Up @@ -135,15 +135,7 @@ export class HTTPRequest implements IOHandler {
}
}

/**
* Load model artifacts via HTTP request(s).
*
* See the documentation to `tf.io.http` for details on the saved
* artifacts.
*
* @returns The loaded model artifacts (if loading succeeds).
*/
async load(): Promise<ModelArtifacts> {
private async loadModelJSON(): Promise<ModelJSON> {
const modelConfigRequest = await this.fetch(this.path, this.requestInit);

if (!modelConfigRequest.ok) {
Expand Down Expand Up @@ -182,18 +174,45 @@ export class HTTPRequest implements IOHandler {
`topology or manifest for weights.`);
}

return modelJSON;
}

/**
 * Load model artifacts via HTTP request(s).
 *
 * See the documentation to `tf.io.http` for details on the saved
 * artifacts.
 *
 * @returns The loaded model artifacts (if loading succeeds).
 */
async load(): Promise<ModelArtifacts> {
  // When weight streaming was requested, hand off to the streaming path
  // and never materialize the full weight buffer here.
  if (this.loadOptions.streamWeights) {
    return this.loadStream();
  }
  const modelJSON = await this.loadModelJSON();
  return getModelArtifactsForJSON(
      modelJSON, (manifest) => this.loadWeights(manifest));
}

private async loadWeights(weightsManifest: WeightsManifestConfig):
Promise<[WeightsManifestEntry[], WeightData]> {
private async loadStream(): Promise<ModelArtifacts> {
  // Resolve the weight shard URLs eagerly; the byte stream itself is
  // created lazily, on each call to getWeightStream.
  const modelJSON = await this.loadModelJSON();
  const urls = await this.getWeightUrls(modelJSON.weightsManifest);
  const weightSpecs = getWeightSpecs(modelJSON.weightsManifest);

  return {
    ...modelJSON,
    weightSpecs,
    getWeightStream: () => streamWeights(urls, this.loadOptions),
  };
}

private async getWeightUrls(weightsManifest: WeightsManifestConfig):
Promise<string[]> {
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
const [prefix, suffix] = parseUrl(weightPath);
const pathPrefix = this.weightPathPrefix || prefix;

const weightSpecs = getWeightSpecs(weightsManifest);

const fetchURLs: string[] = [];
const urlPromises: Array<Promise<string>> = [];
for (const weightsGroup of weightsManifest) {
Expand All @@ -209,12 +228,15 @@ export class HTTPRequest implements IOHandler {
if (this.weightUrlConverter) {
fetchURLs.push(...await Promise.all(urlPromises));
}
return fetchURLs;
}

private async loadWeights(weightsManifest: WeightsManifestConfig):
    Promise<[WeightsManifestEntry[], WeightData]> {
  // Resolve the shard URLs (applying weightPathPrefix / weightUrlConverter
  // via getWeightUrls), then fetch every shard into ArrayBuffers.
  const fetchURLs = await this.getWeightUrls(weightsManifest);
  const weightSpecs = getWeightSpecs(weightsManifest);

  // Fix: the span contained diff residue — both the removed options-object
  // call and the new call were present, declaring `buffers` twice. Keep
  // only the post-commit form, which forwards the full loadOptions
  // (requestInit, fetchFunc, onProgress) in one place.
  const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
  return [weightSpecs, buffers];
}
}
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import './local_storage';

import {browserFiles} from './browser_files';
import {browserHTTPRequest, http, isHTTPScheme} from './http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
import {concatenateArrayBuffers, decodeWeights, decodeWeightsStream, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough';
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types';
Expand All @@ -36,6 +36,7 @@ export {
CompositeArrayBuffer,
concatenateArrayBuffers,
decodeWeights,
decodeWeightsStream,
encodeWeights,
fromMemory,
fromMemorySync,
Expand Down
Loading

0 comments on commit e2ba43c

Please sign in to comment.