This repository has been archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
webgl: instance normalization (#200)
* webgl: instance normalization * add change to operators.md * fix shader for webgl 1.0
- Loading branch information
Showing
4 changed files
with
154 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
import {InstanceNormalization} from '../../../ops/instance-normalization'; | ||
import {Tensor} from '../../../tensor'; | ||
import {getGlsl} from '../glsl-source'; | ||
import {WebGLInferenceHandler} from '../inference-handler'; | ||
import {Artifact, ProgramInfo, RunData, TextureLayout} from '../types'; | ||
|
||
export class WebGLInstanceNormalization extends InstanceNormalization { | ||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { | ||
if (!this.artifacts) { | ||
this.artifacts = []; | ||
const programInfos = this.createProgramInfos(inferenceHandler, inputs); | ||
programInfos.forEach((pi, i) => { | ||
const artifact = inferenceHandler.session.programManager.build(pi); | ||
this.artifacts.push(artifact); | ||
}); | ||
} | ||
|
||
const runDatas = this.createRunDatas(inferenceHandler, this.artifacts.map(a => a.programInfo), inputs); | ||
runDatas.forEach((v, i) => inferenceHandler.session.programManager.run(this.artifacts[i], v)); | ||
return [runDatas[1].outputTextureData.tensor]; | ||
} | ||
|
||
checkInputTypes(inputs: Tensor[]): boolean { | ||
if (!super.checkInputTypes(inputs)) { | ||
return false; | ||
} | ||
|
||
if (inputs[0].dims.length !== 4) { | ||
// currently webgl implementation only support 4-D input. | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
createMeanAndVarianceProgramInfo(inferenceHandler: WebGLInferenceHandler, xLayout: TextureLayout): ProgramInfo { | ||
const xDims = xLayout.shape; | ||
const channel = xDims[1]; | ||
const channelSize = xDims[2] * xDims[3]; | ||
const outputShape = [xDims[0], channel]; | ||
const outputUnpackedShape = [xDims[0], channel * 4]; | ||
|
||
const shaderSource = ` | ||
vec4 process(int[2] indices) { | ||
vec4 v = vec4(0.0); | ||
int a[4]; | ||
a[0] = indices[0]; | ||
a[1] = indices[1]; | ||
float temp = 0.0; | ||
for(int a2=0; a2<${xDims[2]}; a2++) { | ||
a[2] = a2; | ||
for(int a3=0; a3<${xDims[3]}; a3++) { | ||
a[3] = a3; | ||
float x = _X(a); | ||
temp += x; | ||
} | ||
} | ||
float mean = temp / float(${channelSize}); | ||
temp = 0.0; | ||
for(int a2=0; a2<${xDims[2]}; a2++) { | ||
a[2] = a2; | ||
for(int a3=0; a3<${xDims[3]}; a3++) { | ||
a[3] = a3; | ||
float x = _X(a); | ||
temp += (x - mean) * (x - mean); | ||
} | ||
} | ||
v.r = mean; | ||
v.g = temp / float(${channelSize}); | ||
return v; | ||
}`; | ||
return { | ||
inputLayouts: [xLayout], | ||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape, 4, outputUnpackedShape), | ||
samplers: ['X'], | ||
shaderSource, | ||
}; | ||
} | ||
|
||
createComputOutputProgramInfo( | ||
inferenceHandler: WebGLInferenceHandler, xLayout: TextureLayout, scaleLayout: TextureLayout, | ||
bLayout: TextureLayout, meanAndVarianceLayout: TextureLayout): ProgramInfo { | ||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); | ||
const shaderSource = ` | ||
vec4 getMeanAndVariance(int[2] mv) { | ||
int offset = indicesToOffset_MeanAndVariance(mv); | ||
vec2 coords = offsetToCoords(offset, ${meanAndVarianceLayout.width}, ${meanAndVarianceLayout.height}); | ||
return ${glsl.texture2D}(MeanAndVariance, coords); | ||
} | ||
float process(int[4] indices) { | ||
int mv[2]; | ||
mv[0] = indices[0]; | ||
mv[1] = indices[1]; | ||
vec4 mean_and_variance = getMeanAndVariance(mv); | ||
float mean = mean_and_variance.r; | ||
float variance = mean_and_variance.g; | ||
int sb[1]; | ||
sb[0] = indices[1]; | ||
float scale = _Scale(sb); | ||
float b = _B(sb); | ||
return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b; | ||
}`; | ||
return { | ||
inputLayouts: [xLayout, meanAndVarianceLayout, scaleLayout, bLayout], | ||
outputLayout: inferenceHandler.createTextureLayoutFromShape(xLayout.shape), | ||
samplers: ['X', 'MeanAndVariance', 'Scale', 'B'], | ||
variables: [{name: 'epsilon', type: 'float'}], | ||
shaderSource, | ||
}; | ||
} | ||
createProgramInfos(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo[] { | ||
const xLayout = inferenceHandler.getOrCreateTextureLayout(inputs[0]); | ||
const scaleLayout = inferenceHandler.getOrCreateTextureLayout(inputs[1]); | ||
const bLayout = inferenceHandler.getOrCreateTextureLayout(inputs[2]); | ||
const meanAndVarianceProgramInfo = this.createMeanAndVarianceProgramInfo(inferenceHandler, xLayout); | ||
const computeOutputProgramInfo = this.createComputOutputProgramInfo( | ||
inferenceHandler, xLayout, scaleLayout, bLayout, meanAndVarianceProgramInfo.outputLayout); | ||
|
||
const programInfos: ProgramInfo[] = [meanAndVarianceProgramInfo, computeOutputProgramInfo]; | ||
return programInfos; | ||
} | ||
createRunDatas(inferenceHandler: WebGLInferenceHandler, programInfos: ProgramInfo[], inputs: Tensor[]): RunData[] { | ||
const dataType = inputs[0].type; | ||
const inputTD = inferenceHandler.getOrCreateTextureData(inputs[0], programInfos[0].inputLayouts[0]); | ||
const scaleTD = inferenceHandler.getOrCreateTextureData(inputs[1], programInfos[1].inputLayouts[2]); | ||
const bTD = inferenceHandler.getOrCreateTextureData(inputs[2], programInfos[1].inputLayouts[3]); | ||
const runDatas: RunData[] = []; | ||
runDatas.push({ | ||
inputTextureDatas: [inputTD], | ||
outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[0].outputLayout, dataType), | ||
uniformData: {} | ||
}); | ||
runDatas.push({ | ||
inputTextureDatas: [inputTD, runDatas[0].outputTextureData, scaleTD, bTD], | ||
outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[1].outputLayout, dataType), | ||
uniformData: {'epsilon': this.epsilon} | ||
}); | ||
return runDatas; | ||
} | ||
protected artifacts: Artifact[]; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters