Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream weights to the GPU when loading a model #7994

Merged
merged 16 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import {OperationMapper} from '../operations/operation_mapper';

import {GraphExecutor} from './graph_executor';
import {ResourceManager} from './resource_manager';
// tslint:disable-next-line: no-imports-from-dist
import {decodeWeightsStream} from '@tensorflow/tfjs-core/dist/io/io_utils';

export const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
export const DEFAULT_MODEL_NAME = 'model.json';
Expand Down Expand Up @@ -154,7 +156,12 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements

const loadResult = this.handler.load() as ReturnType<IOHandler['load']>;
if (util.isPromise(loadResult)) {
return loadResult.then(artifacts => this.loadSync(artifacts)) as Result;
return loadResult.then(artifacts => {
if (artifacts.getWeightStream == null) {
return this.loadSync(artifacts);
}
return this.loadStreaming(artifacts);
}) as Result;
}

return this.loadSync(loadResult) as Result;
Expand All @@ -167,6 +174,25 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
loadSync(artifacts: io.ModelArtifacts) {
  // Decode the entire weight buffer synchronously into named tensors,
  // then finish model construction through the shared loading path.
  const decodedWeights = this.io.decodeWeights(
      artifacts.weightData, artifacts.weightSpecs);

  return this.loadWithWeightMap(artifacts, decodedWeights);
}

private async loadStreaming(artifacts: io.ModelArtifacts): Promise<boolean> {
  // Guard against handlers that advertise streaming but omit the stream
  // factory. Name the actual missing field (`getWeightStream`) in the
  // error instead of the unrelated `streamWeights` load option.
  if (artifacts.getWeightStream == null) {
    throw new Error('Model artifacts missing getWeightStream function');
  }

  // Decode weights incrementally from the stream so the full weight
  // buffer never has to reside in memory at once.
  const weightMap = await decodeWeightsStream(
      artifacts.getWeightStream(), artifacts.weightSpecs);

  return this.loadWithWeightMap(artifacts, weightMap);
}

private loadWithWeightMap(artifacts: io.ModelArtifacts,
weightMap: NamedTensorMap) {
this.artifacts = artifacts;
const graph = this.artifacts.modelTopology as tensorflow.IGraphDef;

Expand All @@ -184,8 +210,6 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
this.signature = signature;

this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
const weightMap = this.io.decodeWeights(
this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(
OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
Expand Down
38 changes: 36 additions & 2 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import {GraphNode} from '../operations/types';
import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model';
import {HASH_TABLE_MODEL_V2} from './test_data/hash_table_v2_model_loader';
import {STRUCTURED_OUTPUTS_MODEL} from './test_data/structured_outputs_model_loader';
// tslint:disable-next-line: no-imports-from-dist
import {expectArrayBuffersEqual} from '@tensorflow/tfjs-core/dist/test_util';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;
Expand Down Expand Up @@ -125,6 +127,24 @@ const SIMPLE_HTTP_MODEL_LOADER = {
}
};

// Test IO handler whose artifacts expose weights as a ReadableStream
// (via `getWeightStream`) instead of a pre-buffered ArrayBuffer.
const SIMPLE_STREAMING_MODEL_LOADER = {
  load: async () => {
    // Materialize the bias tensor's bytes lazily, each time the stream
    // factory is invoked.
    const streamWeightBytes = () => {
      const weightBytes = bias.dataSync();
      return new Blob([weightBytes]).stream();
    };

    return {
      modelTopology: SIMPLE_MODEL,
      weightSpecs: weightsManifest,
      getWeightStream: streamWeightBytes,
      format: 'tfjs-graph-model',
      generatedBy: '1.15',
      convertedBy: '1.3.1',
      userDefinedMetadata: {signature: SIGNATURE}
    };
  }
};

const NO_INPUT_SIGNATURE_MODEL_LOADER = {
load: async () => {
return {
Expand Down Expand Up @@ -438,7 +458,7 @@ describe('loadGraphModel', () => {
});

it('Pass a fetchFunc', async () => {
const fetchFunc = () => {};
const fetchFunc = (() => {}) as unknown as typeof fetch;
spyIo.getLoadHandlers.and.returnValue([CUSTOM_HTTP_MODEL_LOADER]);
await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo);
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
Expand Down Expand Up @@ -594,7 +614,13 @@ describe('Model', () => {

describe('simple model', () => {
beforeEach(() => {
  // Route load-handler lookups to the streaming loader only when the
  // caller opts in via `streamWeights`. `loadOptions` is itself optional
  // in the fake's signature, so use optional chaining to avoid a
  // TypeError when a test calls getLoadHandlers without options.
  spyIo.getLoadHandlers.and.callFake(
      (_url: string|string[], loadOptions?: io.LoadOptions) => {
        if (loadOptions?.streamWeights) {
          return [SIMPLE_STREAMING_MODEL_LOADER];
        }
        return [SIMPLE_HTTP_MODEL_LOADER];
      });
  spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
});
it('load', async () => {
Expand Down Expand Up @@ -776,6 +802,14 @@ describe('Model', () => {
expect(model).toBeDefined();
});

it('should stream graph model weights', async () => {
  // Loading with `streamWeights: true` must yield the same tensors as
  // the buffered path; compare the decoded constant against the source
  // bias tensor byte-for-byte.
  const streamedModel =
      await loadGraphModel(MODEL_URL, {streamWeights: true}, spyIo);
  expect(streamedModel).toBeDefined();
  expectArrayBuffersEqual(
      streamedModel.weights['Const'][0].dataSync(), bias.dataSync());
});

describe('InferenceModel interface', () => {
it('should expose inputs', async () => {
await model.load();
Expand Down
68 changes: 45 additions & 23 deletions tfjs-core/src/io/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,30 @@ import {assert} from '../util';
import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
import {CompositeArrayBuffer} from './composite_array_buffer';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer} from './weights_loader';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer, streamWeights} from './weights_loader';

const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
const JSON_TYPE = 'application/json';
export class HTTPRequest implements IOHandler {
protected readonly path: string;
protected readonly requestInit: RequestInit;

private readonly fetch: Function;
private readonly fetch: typeof fetch;
private readonly weightUrlConverter: (weightName: string) => Promise<string>;

readonly DEFAULT_METHOD = 'POST';

static readonly URL_SCHEME_REGEX = /^https?:\/\//;

private readonly weightPathPrefix: string;
private readonly onProgress: OnProgressCallback;
private readonly loadOptions: LoadOptions;

constructor(path: string, loadOptions?: LoadOptions) {
if (loadOptions == null) {
loadOptions = {};
}
this.weightPathPrefix = loadOptions.weightPathPrefix;
this.onProgress = loadOptions.onProgress;
this.weightUrlConverter = loadOptions.weightUrlConverter;

if (loadOptions.fetchFunc != null) {
Expand Down Expand Up @@ -84,6 +83,7 @@ export class HTTPRequest implements IOHandler {
'requestInit is expected to have no pre-existing body, but has one.');
}
this.requestInit = loadOptions.requestInit || {};
this.loadOptions = loadOptions;
}

async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
Expand Down Expand Up @@ -135,15 +135,7 @@ export class HTTPRequest implements IOHandler {
}
}

/**
* Load model artifacts via HTTP request(s).
*
* See the documentation to `tf.io.http` for details on the saved
* artifacts.
*
* @returns The loaded model artifacts (if loading succeeds).
*/
async load(): Promise<ModelArtifacts> {
private async loadModelJSON(): Promise<ModelJSON> {
const modelConfigRequest = await this.fetch(this.path, this.requestInit);

if (!modelConfigRequest.ok) {
Expand Down Expand Up @@ -182,18 +174,45 @@ export class HTTPRequest implements IOHandler {
`topology or manifest for weights.`);
}

return modelJSON;
}

/**
* Load model artifacts via HTTP request(s).
*
* See the documentation to `tf.io.http` for details on the saved
* artifacts.
*
* @returns The loaded model artifacts (if loading succeeds).
*/
async load(): Promise<ModelArtifacts> {
if (this.loadOptions.streamWeights) {
return this.loadStream();
}
const modelJSON = await this.loadModelJSON();
return getModelArtifactsForJSON(
modelJSON, (weightsManifest) => this.loadWeights(weightsManifest));
}

private async loadWeights(weightsManifest: WeightsManifestConfig):
Promise<[WeightsManifestEntry[], WeightData]> {
private async loadStream(): Promise<ModelArtifacts> {
  // Resolve the topology and weight URLs eagerly, but expose the weight
  // bytes as a lazy stream: nothing is fetched until the consumer calls
  // `getWeightStream`.
  const modelJSON = await this.loadModelJSON();
  const manifest = modelJSON.weightsManifest;
  const fetchURLs = await this.getWeightUrls(manifest);

  return {
    ...modelJSON,
    weightSpecs: getWeightSpecs(manifest),
    getWeightStream: () => streamWeights(fetchURLs, this.loadOptions),
  };
}

private async getWeightUrls(weightsManifest: WeightsManifestConfig):
Promise<string[]> {
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
const [prefix, suffix] = parseUrl(weightPath);
const pathPrefix = this.weightPathPrefix || prefix;

const weightSpecs = getWeightSpecs(weightsManifest);

const fetchURLs: string[] = [];
const urlPromises: Array<Promise<string>> = [];
for (const weightsGroup of weightsManifest) {
Expand All @@ -209,12 +228,15 @@ export class HTTPRequest implements IOHandler {
if (this.weightUrlConverter) {
fetchURLs.push(...await Promise.all(urlPromises));
}
return fetchURLs;
}

private async loadWeights(weightsManifest: WeightsManifestConfig):
Promise<[WeightsManifestEntry[], WeightData]> {
const fetchURLs = await this.getWeightUrls(weightsManifest);
const weightSpecs = getWeightSpecs(weightsManifest);

const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
requestInit: this.requestInit,
fetchFunc: this.fetch,
onProgress: this.onProgress
});
const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
return [weightSpecs, buffers];
}
}
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import './local_storage';

import {browserFiles} from './browser_files';
import {browserHTTPRequest, http, isHTTPScheme} from './http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
import {concatenateArrayBuffers, decodeWeights, decodeWeightsStream, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough';
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types';
Expand All @@ -36,6 +36,7 @@ export {
CompositeArrayBuffer,
concatenateArrayBuffers,
decodeWeights,
decodeWeightsStream,
encodeWeights,
fromMemory,
fromMemorySync,
Expand Down
Loading