Skip to content

Commit

Permalink
Only upload to GPU if it's above the cpu forwarding threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsoulanille committed Oct 6, 2023
1 parent 69c40a8 commit bf9c6fa
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 44 deletions.
3 changes: 3 additions & 0 deletions tfjs-core/src/io/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,12 @@ export class HTTPRequest implements IOHandler {
private async loadStream(): Promise<ModelArtifacts> {
const modelJSON = await this.loadModelJSON();
const fetchURLs = await this.getWeightUrls(modelJSON.weightsManifest);
const weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
const stream = () => streamWeights(fetchURLs, this.loadOptions);

return {
...modelJSON,
weightSpecs,
streamWeights: stream,
};
}
Expand Down
13 changes: 11 additions & 2 deletions tfjs-core/src/io/io_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {CompositeArrayBuffer} from './composite_array_buffer';
import {Tensor} from '../tensor';
import {backend} from '../globals';
import {DataId} from '../tensor_info';
import { env } from '../environment';

/** Number of bytes reserved for the length of the string. (32bit integer). */
const NUM_BYTES_STRING_LENGTH = 4;
Expand Down Expand Up @@ -398,12 +399,20 @@ export async function decodeWeightsStream(
data = await readToLength(reader, data, NUM_BYTES_STRING_LENGTH);
const byteLength = getWeightBytelength(spec, data);
data = await readToLength(reader, data, byteLength);
const weightTensor = decodeWeight(spec, data);

// Slice the tensor out
const tensorData = data.slice(0, byteLength);
data = data.slice(byteLength);

const weightTensor = decodeWeight(spec, tensorData);
tensors[spec.name] = weightTensor;

// TODO(mattsoulanille): Better way to call uploadToGPU.
const b = backend();
if ('uploadToGPU' in b) {
// TODO(mattsoulanille): Make this work for webgl too.
if ('uploadToGPU' in b &&
sizeFromShape(weightTensor.shape) >= (env()
.get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD') as number)) {
(b.uploadToGPU as (dataId: DataId) => void)(weightTensor.dataId);
}
}
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/io/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ export function streamWeights(fetchURLs: string[], loadOptions: LoadOptions): Re

let fetchIndex = 0;
let chunkReader: ReadableStreamDefaultReader<Uint8Array> | undefined;
loadOptions.onProgress(0);
loadOptions.onProgress?.(0);
return new ReadableStream<Uint8Array>({
pull: async (controller) => {
while (fetchIndex < fetchURLs.length) {
Expand Down
48 changes: 7 additions & 41 deletions tfjs/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1976,14 +1976,6 @@
resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-3.0.4.tgz#f0ec25dbf2f0e4b18647313ac031134ca5b24b21"
integrity sha512-1z8k4wzFnNjVK/tlxvrWuK5WMt6mydWWP7+zvH5eFep4oj+UkrfiJTRtjCeBXNpwaA/FYqqtb4/QS4ianFpIRA==

"@types/node-fetch@^2.1.2":
version "2.6.4"
resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-2.6.4.tgz#1bc3a26de814f6bf466b25aeb1473fa1afe6a660"
integrity sha512-1ZX9fcN4Rvkvgv4E6PAY5WXUFWFcRWxZa3EW83UjycOB9ljJCedb2CupIP4RZMEwF/M3eTcCihbBRgwtGbg5Rg==
dependencies:
"@types/node" "*"
form-data "^3.0.0"

"@types/node@*", "@types/node@>=10.0.0":
version "18.11.9"
resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4"
Expand Down Expand Up @@ -2153,11 +2145,6 @@ async@^3.0.1:
resolved "https://registry.yarnpkg.com/async/-/async-3.2.0.tgz#b3a2685c5ebb641d3de02d161002c60fc9f85720"
integrity sha512-TR2mEZFVOj2pLStYxLht7TyfuRzaydfpxr3k9RpHIzMgw7A64dzsdqCxH1WJyQdoe8T10nDXd9wnEigmiuHIZw==

asynckit@^0.4.0:
version "0.4.0"
resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79"
integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==

available-typed-arrays@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/available-typed-arrays/-/available-typed-arrays-1.0.2.tgz#6b098ca9d8039079ee3f77f7b783c4480ba513f5"
Expand Down Expand Up @@ -2586,13 +2573,6 @@ combine-source-map@^0.8.0:
lodash.memoize "~3.0.3"
source-map "~0.5.3"

combined-stream@^1.0.8:
version "1.0.8"
resolved "https://registry.yarnpkg.com/combined-stream/-/combined-stream-1.0.8.tgz#c3d45a8b34fd730631a110a8a2520682b31d5a7f"
integrity sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==
dependencies:
delayed-stream "~1.0.0"

commander@^2.12.1, commander@^2.20.0:
version "2.20.3"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.20.3.tgz#fd485e84c03eb4881c20722ba48035e8531aeb33"
Expand Down Expand Up @@ -2802,11 +2782,6 @@ define-properties@^1.1.3:
dependencies:
object-keys "^1.0.12"

delayed-stream@~1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619"
integrity sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==

depd@~1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9"
Expand Down Expand Up @@ -3089,15 +3064,6 @@ foreach@^2.0.5:
resolved "https://registry.yarnpkg.com/foreach/-/foreach-2.0.5.tgz#0bee005018aeb260d0a3af3ae658dd0136ec1b99"
integrity sha1-C+4AUBiusmDQo6865ljdATbsG5k=

form-data@^3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/form-data/-/form-data-3.0.1.tgz#ebd53791b78356a99af9a300d4282c4d5eb9755f"
integrity sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg==
dependencies:
asynckit "^0.4.0"
combined-stream "^1.0.8"
mime-types "^2.1.12"

from@~0:
version "0.1.7"
resolved "https://registry.yarnpkg.com/from/-/from-0.1.7.tgz#83c60afc58b9c56997007ed1a768b3ab303a44fe"
Expand Down Expand Up @@ -3927,20 +3893,20 @@ mime-db@1.52.0:
resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.52.0.tgz#bbabcdc02859f4987301c856e3387ce5ec43bf70"
integrity sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==

mime-types@^2.1.12, mime-types@~2.1.34:
version "2.1.35"
resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a"
integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==
dependencies:
mime-db "1.52.0"

mime-types@~2.1.24:
version "2.1.34"
resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.34.tgz#5a712f9ec1503511a945803640fafe09d3793c24"
integrity sha512-6cP692WwGIs9XXdOO4++N+7qjqv0rqxxVvJ3VHPh/Sc9mVZcQP+ZGhkKiTvWMQRr2tbHkJP/Yn7Y0npb3ZBs4A==
dependencies:
mime-db "1.51.0"

mime-types@~2.1.34:
version "2.1.35"
resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a"
integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==
dependencies:
mime-db "1.52.0"

mime@^2.5.2:
version "2.6.0"
resolved "https://registry.yarnpkg.com/mime/-/mime-2.6.0.tgz#a2a682a95cd4d0cb1d6257e28f83da7e35800367"
Expand Down

0 comments on commit bf9c6fa

Please sign in to comment.