Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
operator: Slice v10 (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire authored Apr 11, 2019
1 parent e1fe3a9 commit a9f6889
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 56 deletions.
3 changes: 2 additions & 1 deletion lib/backends/cpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {CpuMatMul} from './ops/matmul';
import {CpuAveragePool, CpuGlobalAveragePool, CpuGlobalMaxPool, CpuMaxPool} from './ops/pool';
import * as cpuReduce from './ops/reduce';
import {CpuReshape} from './ops/reshape';
import {CpuSlice} from './ops/slice';
import {CpuSlice, CpuSliceV10} from './ops/slice';
import {CpuSoftmax} from './ops/softmax';
import {CpuSqueeze} from './ops/squeeze';
import {CpuSum} from './ops/sum';
Expand Down Expand Up @@ -76,6 +76,7 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Reshape', '', '5+', () => new CpuReshape()],
['Sigmoid', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sigmoid)],
['Sin', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sin)],
['Slice', '', '10+', () => new CpuSliceV10()], // TODO: support 'steps' for Slice-10
['Slice', '', '1-9', () => new CpuSlice()],
['Softmax', '', '1+', () => new CpuSoftmax()],
['Sqrt', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sqrt)],
Expand Down
26 changes: 20 additions & 6 deletions lib/backends/cpu/ops/slice.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Slice} from '../../../ops/slice';
import {Slice, SliceV10} from '../../../ops/slice';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {CpuInferenceHandler} from '../inference-handler';
Expand All @@ -13,9 +13,23 @@ export class CpuSlice extends Slice {
}
}

export function slice(x: Tensor, starts: number[], ends: number[], axes: number[]): Tensor {
export class CpuSliceV10 extends SliceV10 {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
if (inputs.length >= 5 && inputs[4].integerData.some((i: number) => i !== 1)) {
throw new Error(`currently non-1 steps is not supported for Slice`);
}
const starts = Array.from(inputs[1].integerData);
const ends = Array.from(inputs[2].integerData);
const axes = inputs.length >= 4 ? Array.from(inputs[3].integerData) : [];
const output = slice(inputs[0], starts, ends, axes);
return [output];
}
}

export function slice(
x: Tensor, starts: ReadonlyArray<number>, ends: ReadonlyArray<number>, axes: ReadonlyArray<number>): Tensor {
if (axes.length === 0) {
axes = x.dims.slice(0).map((val, ind) => ind);
axes = x.dims.map((val, ind) => ind);
}
axes = axes.map(axis => ShapeUtil.parseAxis(axis, x.dims.length));
starts = starts.map((start, ind) => {
Expand All @@ -32,7 +46,7 @@ export function slice(x: Tensor, starts: number[], ends: number[], axes: number[
});
const size: number[] = [];
const adjustedStarts: number[] = [];
axes.map((val, ind) => {
axes.forEach((val, ind) => {
size[val] = ends[ind] - starts[ind];
adjustedStarts[val] = starts[ind];
});
Expand All @@ -45,12 +59,12 @@ export function slice(x: Tensor, starts: number[], ends: number[], axes: number[
const oldDimsStride = ShapeUtil.computeStrides(x.dims ? x.dims : [x.data.length]);
const X = x.data;
const output = new Tensor(size, x.type);
const Y = output.numberData;
const Y = output.data;
for (let i = 0; i < Y.length; ++i) {
const newLogicalIndex = ShapeUtil.offsetToIndices(i, newDimsStride);
const oldLogicalIndex = newLogicalIndex.map((idx, j) => idx + adjustedStarts[j]);
const oldOffset = ShapeUtil.indicesToOffset(oldLogicalIndex, oldDimsStride);
Y[i] = X[oldOffset] as number;
Y[i] = X[oldOffset];
}
return output;
}
3 changes: 2 additions & 1 deletion lib/backends/webgl/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {WebGLPad} from './ops/pad';
import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPool} from './ops/pool';
import * as reduceOps from './ops/reduce';
import {WebGLReshape} from './ops/reshape';
import {WebGLSlice} from './ops/slice';
import {WebGLSlice, WebGLSliceV10} from './ops/slice';
import {WebGLSoftmax} from './ops/softmax';
import {WebGLSplit} from './ops/split';
import {WebGLSqueeze} from './ops/squeeze';
Expand Down Expand Up @@ -82,6 +82,7 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
['Reshape', '', '5+', () => new WebGLReshape()],
['Sigmoid', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid())],
['Sin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin())],
['Slice', '', '10+', () => new WebGLSliceV10()], // TODO: support 'steps' for Slice-10
['Slice', '', '1-9', () => new WebGLSlice()],
['Softmax', '', '1+', () => new WebGLSoftmax()],
// 'Split' operator has an optional attribute 'split'
Expand Down
127 changes: 79 additions & 48 deletions lib/backends/webgl/ops/slice.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Licensed under the MIT license.

import {Slice} from '../../../ops/slice';
import {Slice, SliceV10} from '../../../ops/slice';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';
Expand All @@ -15,58 +15,89 @@ export class WebGLSlice extends Slice implements WebGLOperator {
}

createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const x = inputs[0];
let axes = this.axes;
let starts = this.starts;
let ends = this.ends;

if (axes.length === 0) {
axes = x.dims.slice(0).map((val, ind) => ind);
}
axes = axes.map(axis => ShapeUtil.parseAxis(axis, x.dims.length));
starts = starts.map((start, ind) => {
if (start > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(start, x.dims[axes[ind]]);
});
ends = ends.map((end, ind) => {
if (end > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(end, x.dims[axes[ind]]);
});
return createProgramInfo(handler, inputs[0], this.starts, this.ends, this.axes);
}
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
return createRunData(handler, programInfo, inputs);
}
}

const outputShape = x.dims.slice();
export class WebGLSliceV10 extends SliceV10 implements WebGLOperator {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
}

const sliceOps: string[] = [];
for (let i = 0; i < axes.length; i++) {
outputShape[axes[i]] = ends[i] - starts[i];
if (starts[i] > 0) {
sliceOps.push(`outputIdx[${axes[i]}] += ${starts[i]};`);
} // else { sliceOps.push(`outputIdx[${axes[i]}] += 0;`); }
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
if (!handler.session.isInitializer(inputs[1]) || !handler.session.isInitializer(inputs[2]) ||
(inputs.length >= 4 && !handler.session.isInitializer(inputs[3])) ||
(inputs.length >= 5 && !handler.session.isInitializer(inputs[4]))) {
throw new Error(`dynamic slice attributes are not allowed`);
}
if (inputs.length >= 5 && inputs[4].integerData.some((i: number) => i !== 1)) {
throw new Error(`currently non-1 steps is not supported for Slice`);
}
const starts = Array.from(inputs[1].integerData);
const ends = Array.from(inputs[2].integerData);
const axes = inputs.length >= 4 ? Array.from(inputs[3].integerData) : [];

const rank = outputShape.length;
const shaderSource = `
uniform sampler2D A;
float process(int outputIdx[${rank}]) {
${sliceOps.join('\n ')}
return _A(outputIdx);
}`;
return {
hasMain: false,
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
outputLayout: handler.createBasicTextureLayout(outputShape),
shaderSource,
};
return createProgramInfo(handler, inputs[0], starts, ends, axes);
}

createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
const inputTDs = inputs.map((t, i) => handler.getOrCreate(t, programInfo.inputLayouts[i]));
return {
inputTextureDatas: inputTDs,
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].dataType),
uniformData: {}
};
return createRunData(handler, programInfo, inputs);
}
}

function createProgramInfo(
handler: WebGLInferenceHandler, x: Tensor, starts: ReadonlyArray<number>, ends: ReadonlyArray<number>,
axes: ReadonlyArray<number>): ProgramInfo {
if (axes.length === 0) {
axes = x.dims.slice(0).map((val, ind) => ind);
}
axes = axes.map(axis => ShapeUtil.parseAxis(axis, x.dims.length));
starts = starts.map((start, ind) => {
if (start > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(start, x.dims[axes[ind]]);
});
ends = ends.map((end, ind) => {
if (end > x.dims[axes[ind]] - 1) {
return x.dims[axes[ind]];
}
return ShapeUtil.parseAxis(end, x.dims[axes[ind]]);
});

const outputShape = x.dims.slice();

const sliceOps: string[] = [];
for (let i = 0; i < axes.length; i++) {
outputShape[axes[i]] = ends[i] - starts[i];
if (starts[i] > 0) {
sliceOps.push(`outputIdx[${axes[i]}] += ${starts[i]};`);
} // else { sliceOps.push(`outputIdx[${axes[i]}] += 0;`); }
}

const rank = outputShape.length;
const shaderSource = `
uniform sampler2D A;
float process(int outputIdx[${rank}]) {
${sliceOps.join('\n ')}
return _A(outputIdx);
}`;
return {
hasMain: false,
inputLayouts: [handler.getOrCreateTextureLayout(x)],
outputLayout: handler.createBasicTextureLayout(outputShape),
shaderSource,
};
}

function createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
const inputTDs = [handler.getOrCreate(inputs[0], programInfo.inputLayouts[0])];
return {
inputTextureDatas: inputTDs,
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].dataType),
uniformData: {}
};
}
30 changes: 30 additions & 0 deletions lib/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,33 @@ export abstract class Slice implements Operator {
protected ends: number[];
protected starts: number[];
}

export abstract class SliceV10 implements Operator {
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

initialize(attributes: Attribute): void {}

checkInputs(inputs: Tensor[]): boolean {
if (!inputs || inputs.length < 3 || inputs.length > 5) {
return false;
}
return this.checkInputTypes(inputs);
}

protected checkInputTypes(inputs: Tensor[]): boolean {
if (inputs[1].type !== 'int32' || inputs[1].dims.length !== 1) {
return false;
}
if (inputs[2].type !== 'int32' || inputs[2].dims.length !== 1) {
return false;
}
if (inputs.length >= 4 && (inputs[3].type !== 'int32' || inputs[3].dims.length !== 1)) {
return false;
}
if (inputs.length >= 5 && (inputs[4].type !== 'int32' || inputs[4].dims.length !== 1)) {
return false;
}

return true;
}
}

0 comments on commit a9f6889

Please sign in to comment.