Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
src: clean up operator resolve rules (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire authored Aug 20, 2020
1 parent fcc8d77 commit 10a0b82
Show file tree
Hide file tree
Showing 22 changed files with 293 additions and 236 deletions.
303 changes: 164 additions & 139 deletions docs/operators.md

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions lib/backends/cpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Acosh', '', '9+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.acosh)],
['Add', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 + e2))],
['And', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 && e2))],
['ArgMax', '', '1+', () => new CpuArgMax()],
['ArgMax', '', '1-11', () => new CpuArgMax()],
['Asin', '', '7+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.asin)],
['Asinh', '', '9+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.asinh)],
['Atan', '', '7+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.atan)],
['Atanh', '', '9+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.atanh)],
['AveragePool', '', '7+', () => new CpuAveragePool()], // TODO: support new attributes for AveragePool-10
['AveragePool', '', '7-10', () => new CpuAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new CpuBatchNormalization()],
['Ceil', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.ceil)],
['Clip', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.clip, unaryOps.clipInitializer)],
['Clip', '', '6-10', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.clip, unaryOps.clipInitializer)],
['Concat', '', '4+', () => new CpuConcat()],
['Conv', '', '1+', () => new CpuConv()],
['Cos', '', '7+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.cos)],
Expand All @@ -58,7 +58,8 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Flatten', '', '1+', () => new CpuFlatten()],
['Floor', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.floor)],
['Gather', '', '1+', () => new CpuGather()],
['Gemm', '', '7+', () => new CpuGemm()],
['Gemm', '', '7-10', () => new CpuGemm(false)],
['Gemm', '', '11+', () => new CpuGemm(true)],
['GlobalAveragePool', '', '1+', () => new CpuGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new CpuGlobalMaxPool()],
['ImageScaler', '', '1+', () => new CpuImageScaler()],
Expand All @@ -68,7 +69,7 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Log', '', '6+', () => new CpuUnaryOp(FLOAT_TYPES, unaryOps.log)],
['LRN', '', '1+', () => new CpuLrn()],
['MatMul', '', '1+', () => new CpuMatMul()],
['MaxPool', '', '1+', () => new CpuMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', '', '1-9', () => new CpuMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 * e2))],
['Neg', '', '6+', () => new CpuUnaryOp(NUMBER_TYPES, unaryOps.neg)],
['Not', '', '1+', () => new CpuUnaryOp(['bool'], unaryOps.not, undefined, 'bool')],
Expand Down
2 changes: 1 addition & 1 deletion lib/backends/cpu/ops/argMax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export class CpuArgMax extends ArgMax {

export function argMax(x: Tensor, axis: number, keepdims: boolean): Tensor {
const rank = x.dims ? x.dims.length : 1;
axis = ShapeUtil.parseAxis(axis, rank);
axis = ShapeUtil.normalizeAxis(axis, rank);
const outputDims = ReduceUtil.calcReduceShape(x.dims, [axis], true);
const X = x.data;
const Y = new Int32Array(ShapeUtil.size(outputDims));
Expand Down
5 changes: 3 additions & 2 deletions lib/backends/cpu/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export class CpuGather extends Gather {
}

export function gather(x: Tensor, indices: Tensor, axis: number): Tensor {
axis = ShapeUtil.parseAxis(axis, x.dims.length);
axis = ShapeUtil.normalizeAxis(axis, x.dims.length);
const dims = x.dims.slice();
const newDims = dims.slice();
const indicesData = indices.data;
Expand All @@ -24,7 +24,8 @@ export function gather(x: Tensor, indices: Tensor, axis: number): Tensor {
for (let i = 0; i < Y.length; ++i) {
const newLogicalIndex = ShapeUtil.offsetToIndices(i, newDimsStrides);
const oldLogicalIndex = newLogicalIndex.slice();
oldLogicalIndex[axis] = indicesData[newLogicalIndex[axis]] as number;
const idx = indicesData[newLogicalIndex[axis]] as number;
oldLogicalIndex[axis] = idx < 0 ? idx + dims[axis] : idx;
const oldOffset = ShapeUtil.indicesToOffset(oldLogicalIndex, dimsStrides);
Y[i] = X[oldOffset] as number;
}
Expand Down
10 changes: 6 additions & 4 deletions lib/backends/cpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@ import {matMul2d} from './matmul';

export class CpuGemm extends Gemm {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = gemm(inputs[0], inputs[1], inputs[2], this.alpha, this.beta, this.transA, this.transB);
const output = gemm(
inputs[0], inputs[1], this.alpha, this.beta, this.transA, this.transB,
inputs.length === 3 ? inputs[2] : undefined);
return [output];
}
}

export function gemm(a: Tensor, b: Tensor, c: Tensor, alpha: number, beta: number, transA: boolean, transB: boolean) {
const [M, N, K] = util.GemmUtil.getShapeOfGemmResult(a.dims, transA, b.dims, transB, c.dims);
export function gemm(a: Tensor, b: Tensor, alpha: number, beta: number, transA: boolean, transB: boolean, c?: Tensor) {
const [M, N, K] = util.GemmUtil.getShapeOfGemmResult(a.dims, transA, b.dims, transB, c?.dims);

// The result will always be of the shape [M,N]
const output = new Tensor([M, N], a.type);
// broadcast and assign value from C to output
if (util.BroadcastUtil.calc(output, c, (a, b) => b, true) !== output) {
if (c && util.BroadcastUtil.calc(output, c, (a, b) => b, true) !== output) {
throw new Error(`tensor C is not broadcastable to [M,N]`);
}

Expand Down
14 changes: 7 additions & 7 deletions lib/backends/cpu/ops/reduce.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,49 @@ import {CpuInferenceHandler} from '../inference-handler';

export class CpuReduceSum extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]> {
const output = reduceSum(inputs[0], this.axes, this.keepDims);
const output = reduceSum(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceSumSquare extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceSumSquare(inputs[0], this.axes, this.keepDims);
const output = reduceSumSquare(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceLogSum extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceLogSum(inputs[0], this.axes, this.keepDims);
const output = reduceLogSum(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceMax extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceMax(inputs[0], this.axes, this.keepDims);
const output = reduceMax(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceMin extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceMin(inputs[0], this.axes, this.keepDims);
const output = reduceMin(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceMean extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceMean(inputs[0], this.axes, this.keepDims);
const output = reduceMean(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}

export class CpuReduceProd extends ReduceBase {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = reduceProd(inputs[0], this.axes, this.keepDims);
const output = reduceProd(inputs[0], ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length), this.keepDims);
return [output];
}
}
Expand Down
6 changes: 3 additions & 3 deletions lib/backends/cpu/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@ export function slice(
if (axes.length === 0) {
axes = x.dims.map((val, ind) => ind);
}
axes = axes.map(axis => ShapeUtil.parseAxis(axis, x.dims.length));
axes = ShapeUtil.normalizeAxes(axes, x.dims.length);
starts = starts.map((start, ind) => {
if (start > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(start, x.dims[axes[ind]]);
return ShapeUtil.normalizeAxis(start, x.dims[axes[ind]]);
});
ends = ends.map((end, ind) => {
if (end > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(end, x.dims[axes[ind]]);
return ShapeUtil.normalizeAxis(end, x.dims[axes[ind]]);
});
const size: number[] = [];
const adjustedStarts: number[] = [];
Expand Down
6 changes: 3 additions & 3 deletions lib/backends/cpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ export function softmax(x: Tensor, axis: number): Tensor {
const inputDimensions = x.dims;
const inputRank = inputDimensions.length;

const axisCorrected = util.ShapeUtil.parseAxis(axis, inputRank);
const N = util.ShapeUtil.sizeToDimension(inputDimensions, axisCorrected);
const D = util.ShapeUtil.sizeFromDimension(inputDimensions, axisCorrected);
axis = util.ShapeUtil.normalizeAxis(axis, inputRank);
const N = util.ShapeUtil.sizeToDimension(inputDimensions, axis);
const D = util.ShapeUtil.sizeFromDimension(inputDimensions, axis);

const X = x.numberData;

Expand Down
9 changes: 5 additions & 4 deletions lib/backends/wasm/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ import {WasmSum} from './ops/sum';
export const WASM_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Add', '', '7+', () => new WasmBinaryOp(['float32'], 'Add')],
['And', '', '7+', () => new WasmBinaryOp(['bool'], 'And')],
['AveragePool', '', '7+', () => new WasmAveragePool()], // TODO: support new attributes for AveragePool-10
['AveragePool', '', '7-10', () => new WasmAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new WasmBatchNormalization()],
['Clip', '', '6+', () => new WasmClip()],
['Clip', '', '6-10', () => new WasmClip()],
['Conv', '', '1+', () => new WasmConv()],
['Div', '', '7+', () => new WasmBinaryOp(['float32'], 'Div')],
['Gemm', '', '7+', () => new WasmGemm()],
['Gemm', '', '7-10', () => new WasmGemm(false)],
['Gemm', '', '11+', () => new WasmGemm(true)],
['GlobalAveragePool', '', '1+', () => new WasmGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new WasmGlobalMaxPool()],
['InstanceNormalization', '', '6+', () => new WasmInstanceNormalization()],
['MatMul', '', '1+', () => new WasmMatMul()],
['MaxPool', '', '1+', () => new WasmMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', '', '1-9', () => new WasmMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new WasmBinaryOp(['float32'], 'Mul')],
['Or', '', '7+', () => new WasmBinaryOp(['bool'], 'Or')],
['PRelu', '', '7+', () => new WasmBinaryOp(['float32'], 'PRelu')],
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/wasm/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ export class WasmGemm extends Gemm {
const b = inputs[1];
const c = inputs[2];

const [M, N] = GemmUtil.getShapeOfGemmResult(a.dims, this.transA, b.dims, this.transB, c.dims);
const [M, N] = GemmUtil.getShapeOfGemmResult(a.dims, this.transA, b.dims, this.transB, c?.dims);
const y = new Tensor([M, N], a.type);
if (!BroadcastUtil.calc(y, c, (a, b) => (b), true)) {
if (c && !BroadcastUtil.calc(y, c, (a, b) => (b), true)) {
throw new Error(`c is not broadcastable to the shape of the result of the Gemm operator`);
}
WasmBinding.getInstance().ccall(
Expand Down
6 changes: 3 additions & 3 deletions lib/backends/wasm/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import {WasmInferenceHandler} from '../inference-handler';
export class WasmSoftmax extends Softmax {
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
const x = inputs[0];
const axisCorrected = ShapeUtil.parseAxis(this.axis, x.dims.length);
const N = ShapeUtil.sizeToDimension(x.dims, axisCorrected);
const D = ShapeUtil.sizeFromDimension(x.dims, axisCorrected);
const axis = ShapeUtil.normalizeAxis(this.axis, x.dims.length);
const N = ShapeUtil.sizeToDimension(x.dims, axis);
const D = ShapeUtil.sizeFromDimension(x.dims, axis);
const y = new Tensor(x.dims, x.type);
WasmBinding.getInstance().ccall(
'_softmax_f32', [x.floatData, 'float32ptr'], [y.floatData, 'float32ptr', 'out'], [N, 'int32'], [D, 'int32']);
Expand Down
9 changes: 5 additions & 4 deletions lib/backends/webgl/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['And', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslAnd())],
['Asin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAsin())],
['Atan', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAtan())],
['AveragePool', '', '7+', () => new WebGLAveragePool()], // TODO: support new attributes for AveragePool-10
['AveragePool', '', '7-10', () => new WebGLAveragePool()], // TODO: support new attributes for AveragePool-10
['BatchNormalization', '', '7+', () => new WebGLBatchNormalization()],
['Ceil', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCeil())],
['Clip', '', '6+', () => new WebGLClip()],
['Clip', '', '6-10', () => new WebGLClip()],
['Concat', '', '4+', () => new WebGLConcat()],
['Conv', '', '1+', () => new WebGLConv()],
['Cos', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCos())],
Expand All @@ -55,7 +55,8 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Flatten', '', '1+', () => new WebGLFlatten()],
['Floor', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslFloor())],
['Gather', '', '1+', () => new WebGLGather()],
['Gemm', '', '7+', () => new WebGLGemm()],
['Gemm', '', '7-10', () => new WebGLGemm(false)],
['Gemm', '', '11+', () => new WebGLGemm(true)],
['GlobalAveragePool', '', '1+', () => new WebGLGlobalAveragePool()],
['GlobalMaxPool', '', '1+', () => new WebGLGlobalMaxPool()],
['Greater', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslGreater(), undefined, 'bool')],
Expand All @@ -66,7 +67,7 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Less', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslLess(), undefined, 'bool')],
['Log', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslLog())],
['MatMul', '', '1+', () => new WebGLMatMul()],
['MaxPool', '', '1+', () => new WebGLMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', '', '1-9', () => new WebGLMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
['Mul', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslMul())],
['Neg', '', '6+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslNeg())],
['Not', '', '1+', () => new unaryOps.WebGLUnaryOp(['bool'], unaryOps.glslNot())],
Expand Down
19 changes: 11 additions & 8 deletions lib/backends/webgl/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import {Gather} from '../../../ops/gather';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';

Expand All @@ -19,22 +20,23 @@ export class WebGLGather extends Gather implements WebGLOperator {
throw Error('A scalar tensor output has not been supported');
}

const axis = ShapeUtil.normalizeAxis(this.axis, inputShape.length);
const indexCopyOps: string[] = [];
for (let i = 0; i < outputShape.length; i++) {
// outputShape is divided into three parts: A, B, C
// |0 this.axis| this.axis + indexDataShape.length| end|
// | A | B | C |
// |0 axis| axis + indexDataShape.length | end|
// | A | B | C |
//
// inputIdx: [A, inputs[1][B], C]
if (i < this.axis) { // A
if (i < axis) { // A
outputShape[i] = inputShape[i];
indexCopyOps.push(`inputIdx[${i}] = outputIdx[${i}];`);
} else {
if (i < this.axis + indexDataShape.length) { // B
outputShape[i] = indexDataShape[i - this.axis];
indexCopyOps.push(`indexDataIdx[${i - this.axis}] = outputIdx[${i}];`);
if (i < axis + indexDataShape.length) { // B
outputShape[i] = indexDataShape[i - axis];
indexCopyOps.push(`indexDataIdx[${i - axis}] = outputIdx[${i}];`);
} else { // C
outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for this.axis
outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for axis
indexCopyOps.push(`inputIdx[${i - indexDataShape.length + 1}] = outputIdx[${i}];`);
}
}
Expand All @@ -48,7 +50,8 @@ export class WebGLGather extends Gather implements WebGLOperator {
int inputIdx[${irank}];
int indexDataIdx[${iDrank}];
${indexCopyOps.join('\n ')}
inputIdx[${this.axis}] = int(_B(indexDataIdx));
int idx = int(_B(indexDataIdx));
inputIdx[${axis}] = idx < 0 ? idx + ${inputShape[axis]} : idx;
return _A(inputIdx);
}`;
return {
Expand Down
16 changes: 9 additions & 7 deletions lib/backends/webgl/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ export class WebGLGemm extends Gemm implements WebGLOperator {
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const aShape = inputs[0].dims.slice();
const bShape = inputs[1].dims.slice();
const cShape = inputs[2].dims.slice();
const [M, N] = GemmUtil.getShapeOfGemmResult(aShape, this.transA, bShape, this.transB, cShape);
const [M, N] = GemmUtil.getShapeOfGemmResult(
aShape, this.transA, bShape, this.transB, inputs.length === 3 ? inputs[2].dims : undefined);
const oShape = [M, N];
if (!oShape) {
throw new Error('Can\'t use gemm on the given tensors');
Expand All @@ -35,16 +35,18 @@ export class WebGLGemm extends Gemm implements WebGLOperator {
line = `value += _A(a) * _B(b);`;
}
const rank = oShape.length;
const cRank = cShape.length;
const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : '';
const broadcastC = inputs.length === 3 ? `bcastIndices_C(indices, c);` : '';
const calculateC = inputs.length === 3 ? `value += beta * _C(c);` : '';
const shaderSource = `
float process(int indices[${rank}]) {
int a[${rank}];
int b[${rank}];
int c[${cRank}];
${declareC}
copyVec(indices, a);
copyVec(indices, b);
bcastIndices_C(indices, c);
${broadcastC}
float value = 0.0;
for (int k=0; k<${sharedDim}; ++k) {
Expand All @@ -54,14 +56,14 @@ export class WebGLGemm extends Gemm implements WebGLOperator {
}
value = value * alpha;
value += beta * _C(c);
${calculateC}
return value;
}`;
const inputLayouts = inputs.map(t => inferenceHandler.getOrCreateTextureLayout(t));
return {
inputLayouts,
outputLayout: inferenceHandler.createTextureLayoutFromShape(oShape),
samplers: ['A', 'B', 'C'],
samplers: inputs.length === 3 ? ['A', 'B', 'C'] : ['A', 'B'],
variables: [{name: 'alpha', type: 'float'}, {name: 'beta', type: 'float'}],
shaderSource,
};
Expand Down
Loading

0 comments on commit 10a0b82

Please sign in to comment.