Skip to content

Commit

Permalink
Do not check webgpu handoff if not using webgpu
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsoulanille committed Nov 15, 2023
1 parent 5758ec0 commit 5c93720
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions tfjs-core/src/io/io_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {Tensor} from '../tensor';
import {backend} from '../globals';
import {DataId} from '../tensor_info';
import { env } from '../environment';
import { getBackend } from '../globals';

/** Number of bytes reserved for the length of the string. (32bit integer). */
const NUM_BYTES_STRING_LENGTH = 4;
Expand Down Expand Up @@ -337,12 +338,15 @@ export async function decodeWeightsStream(
tensors[spec.name] = weightTensor;

// TODO(mattsoulanille): Better way to call uploadToGPU.
const b = backend();
// TODO(mattsoulanille): Make this work for webgl too.
if ('uploadToGPU' in b &&
sizeFromShape(weightTensor.shape) >= (env()
.get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD') as number)) {
(b.uploadToGPU as (dataId: DataId) => void)(weightTensor.dataId);
if (getBackend() === 'webgpu') {
const b = backend();

if ('uploadToGPU' in b &&
sizeFromShape(weightTensor.shape) >= (env()
.get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD') as number)) {
(b.uploadToGPU as (dataId: DataId) => void)(weightTensor.dataId);
}
}
}

Expand Down

0 comments on commit 5c93720

Please sign in to comment.