[WebGPU/JS] Added Pad operator support (#16928)
### Description

Adds Pad operator support (constant, reflect, edge, and wrap modes) to the WebGPU/JS execution provider, implements the WGSL shader in `js/web/lib/wasm/jsep/webgpu/ops/pad.ts`, and enables the corresponding operator and ONNX node tests.
### Motivation and Context

The WebGPU/JS execution provider previously had no Pad kernel; this change registers Pad for opsets 2 through 19+ so that models using the operator can run on the WebGPU backend.
xhcao authored Sep 14, 2023
1 parent e11849e commit 198d468
Showing 7 changed files with 379 additions and 5 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
@@ -59,6 +59,7 @@ Do not modify directly.*
| Mul | ai.onnx(7-12,13,14+) | |
| Neg | ai.onnx(6-12,13+) | |
| Not | ai.onnx(1+) | |
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
| Pow | ai.onnx(7-11,12,13-14,15+) | |
| Reciprocal | ai.onnx(6-12,13+) | |
| ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | |
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -14,6 +14,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {pad, parsePadAttributes} from './ops/pad';
import * as pool from './ops/pool';
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
@@ -80,6 +81,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Mul', [binaryOps.mul]],
['Neg', [unaryOps.neg]],
['Not', [unaryOps.not]],
['Pad', [pad, parsePadAttributes]],
['Pow', [binaryOps.pow]],
['Reciprocal', [unaryOps.reciprocal]],
['ReduceMin', [reduceMin, parseReduceAttributes]],
252 changes: 252 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/pad.ts
@@ -0,0 +1,252 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';

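// ONNX Pad lays out 'pads' as [begin_0, ..., begin_{rank-1}, end_0, ..., end_{rank-1}];
// e.g. pads = [1, 2, 1, 2] applied to a 2x3 input yields a 4x7 output.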
export interface PadAttributes extends AttributeWithCacheKey {
// 0-constant, 1-reflect, 2-edge, 3-wrap
readonly mode: number;
readonly value: number;
readonly pads: number[];
}

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length < 1) {
throw new Error('Too few inputs');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Input type must be float.');
}

if (inputs.length >= 2) {
let validPads = inputs[0].dims.length * 2 === inputs[1].dims[0];
if (inputs.length === 4) {
validPads = inputs[3].dims[0] * 2 === inputs[1].dims[0];
}
if (!validPads) {
throw new Error('The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].');
}
}
};

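// Emits WGSL for 'constant' mode: each output coordinate is mapped back to an input coordinate by
// subtracting the leading pad per dimension. The single-iteration loop exists so that `break` can
// skip the `x[offset]` load whenever a coordinate falls in the padded region, leaving the constant value.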
const getPadConstant =
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
inputStrides: readonly number[], pads: number[], dataType: string, constantValue: number): string => {
const inputRank = inputDims.length;

let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
if (k < 0) {
break;
}
if (k >= ${inputDims[i]}) {
break;
}
offset += k * ${inputStrides[i]};
`;
}

return `
value = ${dataType}(${constantValue});
for (var i = 0; i < 1; i++) {
var offset = 0;
var k = 0;
${block}
value = x[offset];
}
`;
};

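// Emits WGSL for 'reflect' mode: out-of-range coordinates are mirrored about the input edges,
// using the reflection period 2 * (dim - 1).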
const getPadReflect =
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
inputStrides: readonly number[], pads: number[]): string => {
const inputRank = inputDims.length;

let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
if (k < 0) {
k = -k;
}
{
let _2n_1 = ${2 * (inputDims[i] - 1)};
k = k % _2n_1;
if (k >= ${inputDims[i]}) {
k = _2n_1 - k;
}
}
offset += k * ${inputStrides[i]};
`;
}

return `
var offset = 0;
var k = 0;
${block}
value = x[offset];
`;
};

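// Emits WGSL for 'edge' mode: out-of-range coordinates are clamped to [0, dim - 1].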
const getPadEdge =
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
inputStrides: readonly number[], pads: number[]): string => {
const inputRank = inputDims.length;

let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
if (k < 0) {
k = 0;
}
if (k >= ${inputDims[i]}) {
k = ${inputDims[i] - 1};
}
offset += k * ${inputStrides[i]};
`;
}

return `
var offset = 0;
var k = 0;
${block}
value = x[offset];
`;
};

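// Emits WGSL for 'wrap' mode: out-of-range coordinates wrap around to the opposite edge of the input.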
const getPadWrap =
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
inputStrides: readonly number[], pads: number[]): string => {
const inputRank = inputDims.length;

let block = '';
for (let i = inputRank - 1; i >= 0; --i) {
block += `
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
if (k < 0) {
k += ${inputDims[i]};
}
if (k >= ${inputDims[i]}) {
k -= ${inputDims[i]};
}
offset += k * ${inputStrides[i]};
`;
}

return `
var offset = 0;
var k = 0;
${block}
value = x[offset];
`;
};

const getPadSnippet =
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
inputStrides: readonly number[], attributes: PadAttributes, dataType: string): string => {
switch (attributes.mode) {
case 0:
return getPadConstant(
output, outputDims, inputDims, inputStrides, attributes.pads, dataType, attributes.value);
case 1:
return getPadReflect(output, outputDims, inputDims, inputStrides, attributes.pads);
case 2:
return getPadEdge(output, outputDims, inputDims, inputStrides, attributes.pads);
case 3:
return getPadWrap(output, outputDims, inputDims, inputStrides, attributes.pads);
default:
throw new Error('Invalid mode');
}
};

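// Assembles the full shader: one invocation per output element derives the output indices from
// 'global_idx', then the mode-specific snippet resolves the matching input offset (or keeps the
// constant value) and writes the result.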
const generatePadCode =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: PadAttributes, dataType: string):
string => {
const inputDims = inputs[0].dims;
const outputDims = ShapeUtil.padShape(inputDims.slice(), attributes.pads);
const outputSize = ShapeUtil.size(outputDims);
const inputStrides = ShapeUtil.computeStrides(inputDims);

const output = outputVariable('output', inputs[0].dataType, outputDims);
const input = inputVariable('x', inputs[0].dataType, inputDims);

const padSnippet = getPadSnippet(output, outputDims, inputDims, inputStrides, attributes, dataType);
const padCode = `
${shaderHelper.declareVariables(input, output)}
${output.impl()}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let indices = ${output.offsetToIndices('global_idx')};
var value = ${dataType}(0);
${padSnippet}
output[global_idx] = value;
}`;
return padCode;
};

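// The output shape is the input shape expanded by the pads; the dispatch uses one 64-invocation
// workgroup per 64 output elements.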
const createPadProgramInfo =
(inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: PadAttributes): ProgramInfo => {
const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads);
return {
...metadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'),
dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)})
};
};

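// For opset 11 and later, pads, constant_value and axes are provided as inputs 1-3 rather than as
// attributes, so the cached attributes are rebuilt from those inputs here.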
const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => {
if (inputs.length > 1) {
const bigInt64Pads = inputs[1].getBigInt64Array();
const value = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : 0.0;

const inputRank = inputs[0].dims.length;
const updatePads = new Int32Array(2 * inputRank).fill(0);
if (inputs.length >= 4) {
const axes = inputs[3].getBigInt64Array();
for (let i = 0; i < axes.length; i++) {
updatePads[Number(axes[i])] = Number(bigInt64Pads[i]);
updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]);
}
} else {
bigInt64Pads.forEach((v, i) => updatePads[i] = Number(v));
}

const pads: number[] = [];
updatePads.forEach(v => pads.push(v));

return createAttributeWithCacheKey({mode: attributes.mode, value, pads});
} else {
return attributes;
}
};

const createPadProgramInfoLoader = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfoLoader => {
const updatedAttributes = createPadAttributesFromInputs(inputs, attributes);
const metadata: ProgramMetadata = {name: 'Pad', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
return {...metadata, get: () => createPadProgramInfo(inputs, metadata, updatedAttributes)};
};

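// Only input 0 is bound to the GPU program ({inputs: [0]}); pads, constant_value and axes are
// consumed on the CPU when the attributes are rebuilt above.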
export const pad = (context: ComputeContext, attributes: PadAttributes): void => {
validateInputs(context.inputs);
context.compute(createPadProgramInfoLoader(context.inputs, attributes), {inputs: [0]});
};

export const parsePadAttributes = (attributes: Record<string, unknown>): PadAttributes => {
const mode = attributes.mode as number;
const value = attributes.value as number;
const pads = attributes.pads as number[];
return createAttributeWithCacheKey({mode, value, pads});
};
11 changes: 6 additions & 5 deletions js/web/test/suite-test-list.jsonc
@@ -505,7 +505,7 @@
// // "test_dynamicquantizelinear_min_adjusted_expanded",
// // "test_dynamicquantizelinear_min_adjusted",
// // "test_dynamicquantizelinear",
// // "test_edge_pad",
"test_edge_pad",
// "test_einsum_batch_diagonal",
// "test_einsum_batch_matmul",
// "test_einsum_inner_prod",
@@ -965,7 +965,7 @@
"test_reduce_sum_square_keepdims_random",
"test_reduce_sum_square_negative_axes_keepdims_example",
"test_reduce_sum_square_negative_axes_keepdims_random",
// // "test_reflect_pad",
"test_reflect_pad",
"test_relu",
// "test_reshape_allowzero_reordered",
"test_reshape_extended_dims",
@@ -1308,7 +1308,8 @@
"test_unsqueeze_three_axes",
"test_unsqueeze_two_axes",
"test_unsqueeze_unsorted_axes",
"test_unsqueeze"
"test_unsqueeze",
"test_wrap_pad"
// "test_upsample_nearest",
// "test_where_example",
// "test_where_long_example",
@@ -1361,8 +1362,8 @@
"reduce-min.jsonc",
"relu.jsonc",
"gelu.jsonc",
//"pad.jsonc",
//"pad-big.jsonc",
"pad.jsonc",
"pad-big.jsonc",
"pow.jsonc",
"pow_int32.jsonc",
"pow-big-number.jsonc",
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.cc
@@ -321,6 +321,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad);

std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();

@@ -577,6 +583,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad)>,

};

for (auto& function_table_entry : function_table) {