operator: Support Clip op for cpu, wasm, and webgl backends #107
@@ -0,0 +1,34 @@ (new file: the WASM backend kernel for Clip)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Clip} from '../../../ops/clip';
import {Tensor} from '../../../tensor';
import {WasmBinding} from '../../../wasm-binding';
import {WasmInferenceHandler} from '../inference-handler';

export class WasmClip extends Clip {
  run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
    const result = new Tensor(inputs[0].dims, inputs[0].type);
    const size = result.floatData.length;
    if (inputs[0].type === 'float32') {
      WasmBinding.getInstance().ccall(
          '_clip_f32', [inputs[0].floatData, 'float32ptr'], [result.floatData, 'float32ptr', 'out'], [size, 'int32'],
          [this.min, 'float32'], [this.max, 'float32']);
    } else {
      // Expand for different types supported by this specific kernel of Clip.
      throw new Error(`Unsupported input type for Clip operator.`);
    }
    return [result];
  }

  // Overrides checkInputTypes() in the base class because the Wasm backend has special type limitations.
  checkInputTypes(inputs: Tensor[]): boolean {
    // Currently the Wasm backend only supports the 'float32' input type.
    if (inputs[0].type !== 'float32') {
      return false;
    }
    return true;
  }
}
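For reference, the computation this kernel delegates to `_clip_f32` is just an element-wise clamp. The following plain-TypeScript equivalent is illustrative only (it is not part of the PR) and shows exactly what the WASM call computes:

```ts
// Illustrative TypeScript equivalent of the _clip_f32 WASM kernel:
// clamp each element of a float32 buffer to the range [min, max].
function clipReference(input: Float32Array, min: number, max: number): Float32Array {
  const output = new Float32Array(input.length);
  for (let i = 0; i < input.length; ++i) {
    const v = input[i];
    output[i] = v < min ? min : v > max ? max : v;
  }
  return output;
}
```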
@@ -0,0 +1,42 @@ (new file: the WebGL backend kernel for Clip)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Clip} from '../../../ops/clip';
import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo} from '../program-info';
import {RunData} from '../program-manager';
import {WebGLOperator} from '../webgl-operator';
import {WebGLOperatorHelper} from '../webgl-operator-utils';

export class WebGLClip extends Clip implements WebGLOperator {
  run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
    return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
  }
  createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
    const outputShape = inputs[0].dims.slice();
    const shaderSource = `
      const float min = float(${this.min});
      const float max = float(${this.max});
      uniform sampler2D A;
      void main() {
        float v = texture2D(A, TexCoords).r;
        gl_FragColor = vec4(clamp(v, min, max));
      }
      `;
    return {
      hasMain: true,
      inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
      outputLayout: handler.createBasicTextureLayout(outputShape),
      shaderSource,
    };
  }
  createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
    const inputTDs = [handler.getOrCreate(inputs[0], programInfo.inputLayouts[0])];
    return {
      inputTextureDatas: inputTDs,
      outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].dataType),
      uniformData: {}
    };
  }
}
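For concreteness, here is what the template literal in `createProgramInfo()` expands to for `min` = 0 and `max` = 6 (values chosen purely for illustration): the attribute values are baked into the shader as compile-time constants, each texel of input texture `A` is clamped, and the scalar result is splatted across the output `vec4`.

```ts
// Illustrative expansion of shaderSource for min = 0, max = 6.
const shaderSource = `
      const float min = float(0);
      const float max = float(6);
      uniform sampler2D A;
      void main() {
        float v = texture2D(A, TexCoords).r;
        gl_FragColor = vec4(clamp(v, min, max));
      }
      `;
```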
@@ -0,0 +1,35 @@ (new file: the backend-agnostic base class for Clip)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../attribute';
import {InferenceHandler} from '../backend';
import {Operator} from '../operators';
import {Tensor} from '../tensor';

export abstract class Clip implements Operator {
  abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

  initialize(attributes: Attribute): void {
    // The defaults are the lowest and highest finite float32 values.
    this.min = attributes.getFloat('min', -3.4028234663852886e+38);
    this.max = attributes.getFloat('max', 3.4028234663852886e+38);
  }

  checkInputs(inputs: Tensor[]): boolean {
    if (!inputs || inputs.length !== 1) {
      return false;
    }
    return this.checkInputTypes(inputs);
  }

  protected checkInputTypes(inputs: Tensor[]): boolean {
    if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
      return false;
    }
    return true;
  }

  protected min: number;
  protected max: number;
}
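The PR title also covers a CPU backend, whose file is not shown in this excerpt. A minimal CPU kernel extending the base class above might look roughly like the sketch below; the `CpuInferenceHandler` name and import path are assumptions modeled on the WASM kernel, not code from this diff.

```ts
// Hypothetical sketch of a CPU backend kernel for Clip (not part of this diff).
import {Clip} from '../../../ops/clip';
import {Tensor} from '../../../tensor';
import {CpuInferenceHandler} from '../inference-handler';  // assumed path

export class CpuClip extends Clip {
  run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
    const result = new Tensor(inputs[0].dims, inputs[0].type);
    const input = inputs[0].floatData;
    const output = result.floatData;
    // Element-wise clamp to [this.min, this.max], matching the WASM kernel.
    for (let i = 0; i < input.length; ++i) {
      const v = input[i];
      output[i] = v < this.min ? this.min : v > this.max ? this.max : v;
    }
    return [result];
  }
}
```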
@@ -0,0 +1,15 @@ (new file: the C++ implementation of the WASM Clip kernel)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "clip.h"

// Wasm interop methods
void clip_f32(void *data) {
  uint32_t *dataIndex = static_cast<uint32_t *>(data);
  const float *input = PARAM_FLOAT_PTR(data, dataIndex[1]);
  float *output = PARAM_FLOAT_PTR(data, dataIndex[2]);
  const int32_t length = PARAM_INT32(data, dataIndex[3]);
  const float min = PARAM_FLOAT(data, dataIndex[4]);
  const float max = PARAM_FLOAT(data, dataIndex[5]);
  clip_imp<float>(input, output, length, min, max);
}
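Note how the typed argument tuples passed to `ccall` on the TypeScript side line up with the `dataIndex` slots unpacked here: slot 1 carries the input pointer, slot 2 the output pointer, slot 3 the element count, and slots 4 and 5 the `min`/`max` attribute values. The role of slot 0 and the definitions of the `PARAM_*` helpers live in `common.h`, which is not part of this diff.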
@@ -0,0 +1,21 @@ (new file: the header for the WASM Clip kernel)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include "common.h"

extern "C" {
void clip_f32(void *);
// Expand for other supported data types for `clip`
}
// Core implementation of the op
template <typename T>
void clip_imp(const T *input, T *output, const int32_t length, const float min,
              const float max) {
  for (int32_t i = 0; i < length; ++i) {
    const auto &val = input[i];
    output[i] = (val < min) ? min : (val > max) ? max : val;
  }
}

Review discussion on `clip_imp`:

> No use of templatizing like `unary-op` in the WASM backend, as there is not much to templatize (no core implementation here). In fact, the overhead in dealing with types and attribute parsing (which is much more prevalent than in the `binary-op` case) makes it an unattractive option. So, keep the kernels separate even for `unary-op`-like implementations.

> I think it is OK for now. If we are going to have 10+ unary-op implementations, that will be the time to consider this problem.

> I think we can potentially support all unary-ops in WASM. Even then, I would find it hard to templatize the core implementation, as each has separate attributes. Since there is no native "attribute" object in our WASM layer, it becomes really hard to templatize this.
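To make the reviewers' point concrete: a generic element-wise runner is easy to write, but each unary op carries different attributes (Clip has `min`/`max`, other ops have their own), and attributes cannot cross the WASM boundary as closures, so every op still needs its own marshalling wrapper and C entry point. The sketch below is a hypothetical illustration of that argument, not code from the PR.

```ts
// Hypothetical illustration of the templatizing problem discussed above.
// A generic element-wise runner is trivial in pure TypeScript...
function runUnaryOp(input: Float32Array, f: (v: number) => number): Float32Array {
  const output = new Float32Array(input.length);
  for (let i = 0; i < input.length; ++i) {
    output[i] = f(input[i]);
  }
  return output;
}

// ...but a closure such as
//   (v) => Math.max(min, Math.min(max, v))
// cannot be passed into WASM: each kernel's attributes (min/max here)
// must be marshalled explicitly through the interop buffer, so every op
// ends up with its own ccall wrapper and C function regardless.
```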