Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

operator: InstanceNormalization operator for wasm and cpu backends #82

Merged
merged 18 commits into from
Feb 15, 2019
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/backends/cpu/ops-resolve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {CpuDropout} from './ops/dropout';
import {CpuGather} from './ops/gather';
import {CpuGemm} from './ops/gemm';
import {CpuImageScaler} from './ops/image-scaler';
import {CpuInstanceNormalization} from './ops/instance-normalization';
import {CpuLrn} from './ops/lrn';
import {CpuMatMul} from './ops/matmul';
import {CpuAveragePool, CpuGlobalAveragePool, CpuGlobalMaxPool, CpuMaxPool} from './ops/pool';
Expand Down Expand Up @@ -119,6 +120,8 @@ function createOperator(node: Graph.Node, domain: string, version: number): Oper
return new CpuGlobalMaxPool();
case 'GlobalAveragePool':
return new CpuGlobalAveragePool();
case 'InstanceNormalization':
return new CpuInstanceNormalization();
case 'PRelu':
return new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 >= 0 ? e1 : e1 * e2));
case 'Reshape':
Expand Down
65 changes: 65 additions & 0 deletions lib/backends/cpu/ops/instance-normalization.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {InstanceNormalization} from '../../../ops/instance-normalization';
import {Tensor} from '../../../tensor';
import {CpuInferenceHandler} from '../inference-handler';

export class CpuInstanceNormalization extends InstanceNormalization {
  /**
   * Executes InstanceNormalization on the CPU backend.
   * Expects inputs = [X, scale, B] as validated by the base class.
   */
  run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
    const [x, scale, b] = inputs;
    return [instanceNormalization(x, scale, b, this.epsilon)];
  }
}

export function instanceNormalization(x: Tensor, scale: Tensor, b: Tensor, epsilon: number) {
const inputDimensions = x.dims;
const N = inputDimensions[0];
const C = inputDimensions[1];

// calculate channel size (i.e.) data points per channel
let channelSize = 1;
for (let i = 2; i < inputDimensions.length; i++) {
channelSize *= inputDimensions[i];
}
const output = new Tensor(x.dims, x.type);

const X = x.floatData;
const Y = output.floatData;
const scaleData = scale.numberData;
const bData = b.numberData;

let temp: number;
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
let mean: number;
let variance: number;
let physicalOffset: number;
let iterEnd: number;
let currentChannel: number;

for (let nc = 0; nc < N * C; nc++) {
physicalOffset = nc * channelSize;
iterEnd = physicalOffset + channelSize;
currentChannel = nc % C;

// compute mean for this channel
temp = 0;
for (let i = physicalOffset; i < iterEnd; ++i) {
temp += X[i];
}
mean = temp / channelSize;

// compute variance for this channel
temp = 0;
for (let i = physicalOffset; i < iterEnd; ++i) {
temp += Math.pow(X[i] - mean, 2);
}
variance = temp / channelSize;

// compute normalized value for data in this channel
for (let i = physicalOffset; i < iterEnd; ++i) {
Y[i] = scaleData[currentChannel] * ((X[i] - mean) / Math.sqrt(variance + epsilon)) + bData[currentChannel];
}
}

return output;
}
51 changes: 51 additions & 0 deletions lib/backends/wasm/ops/instance-normalization.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {InstanceNormalization} from '../../../ops/instance-normalization';
import {Tensor} from '../../../tensor';
import {WasmBinding} from '../../../wasm-binding';
import {WasmInferenceHandler} from '../inference-handler';

export class WasmInstanceNormalization extends InstanceNormalization {
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
const x = inputs[0];
const scale = inputs[1];
const b = inputs[2];

// calculate channel size (i.e.) data points per channel
let channelSize = 1;
for (let i = 2; i < x.dims.length; i++) {
channelSize *= x.dims[i];
}

// create output Tensor after determining output size
const y = new Tensor(x.dims, x.type);
WasmBinding.getInstance().ccall(
'_instance_normalization_f32', [x.floatData, 'float32ptr'], [y.floatData, 'float32ptr', 'out'],
[x.dims[0], 'int32'], [x.dims[1], 'int32'], [channelSize, 'int32'], [scale.floatData, 'float32ptr'],
[b.floatData, 'float32ptr'], [this.epsilon, 'float32']);

return [y];
}

// overriding the checkInputTypes() in the base class because Wasm backend has special type limitations
checkInputTypes(inputs: Tensor[]): boolean {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we apply the special type check only? I think typescript keyword super can help to call parent class method.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now checkInputTypes() in the base class also does some other validation other than type validation and hence we need to copy over that snippet of code as well in this method as this method overrides completely the base method. We can make a change later to make it in such a way that only type check is overridden and other kinds of input validation is still common across base and derived classes (hence prevent code duplication). But there are one or two more ops that require this kind of change. So I prefer to do this separately in another change.

const X = inputs[0];
const scale = inputs[1];
const B = inputs[2];

// input should atleast have three dimensions - N,C,dim1,...,dimn
// other inputs need to be one dimensional
if (X.dims.length < 3 || scale.dims.length !== 1 || B.dims.length !== 1) {
return false;
}
if (scale.dims[0] !== X.dims[1] || B.dims[0] !== X.dims[1]) {
return false;
}
// currently Wasm backend only supports 'float32' input type
if (X.type !== 'float32' || scale.type !== 'float32' || B.type !== 'float32') {
return false;
}
return true;
}
}
3 changes: 3 additions & 0 deletions lib/backends/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {WasmBatchNormalization} from './ops/batch-normalization';
import {WasmBinaryOp} from './ops/binary-op';
import {WasmConv} from './ops/conv';
import {WasmGemm} from './ops/gemm';
import {WasmInstanceNormalization} from './ops/instance-normalization';
import {WasmMatMul} from './ops/matmul';
import {WasmAveragePool, WasmGlobalAveragePool, WasmGlobalMaxPool, WasmMaxPool} from './ops/pool';
import {WasmSoftmax} from './ops/softmax';
Expand Down Expand Up @@ -72,6 +73,8 @@ export class WasmSessionHandler implements SessionHandler {
return new WasmGlobalMaxPool();
case 'GlobalAveragePool':
return new WasmGlobalAveragePool();
case 'InstanceNormalization':
return new WasmInstanceNormalization();
case 'PRelu':
return new WasmBinaryOp(['float32'], 'PRelu');
default:
Expand Down
45 changes: 45 additions & 0 deletions lib/ops/instance-normalization.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../attribute';
import {InferenceHandler} from '../backend';
import {Operator} from '../operators';
import {Tensor} from '../tensor';

/**
 * Backend-independent base for the InstanceNormalization operator.
 * Holds the `epsilon` attribute and the shared input validation; each
 * backend supplies the actual computation in run().
 */
export abstract class InstanceNormalization implements Operator {
  abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

  initialize(attributes: Attribute): void {
    // ONNX default for epsilon is 1e-5
    this.epsilon = attributes.getFloat('epsilon', 1e-5);
  }

  checkInputs(inputs: Tensor[]): boolean {
    // exactly three inputs are required: X, scale, B
    if (!inputs || inputs.length !== 3) {
      return false;
    }
    return this.checkInputTypes(inputs);
  }

  protected checkInputTypes(inputs: Tensor[]): boolean {
    const [x, scale, bias] = inputs;

    // X must be at least 3-D (N, C, dim1, ..., dimn);
    // scale and B must be one dimensional
    if (x.dims.length < 3 || scale.dims.length !== 1 || bias.dims.length !== 1) {
      return false;
    }
    // scale and B must carry exactly one value per channel
    if (scale.dims[0] !== x.dims[1] || bias.dims[0] !== x.dims[1]) {
      return false;
    }
    // only floating-point tensors are supported by the generic check
    const isFloat = (t: Tensor) => t.type === 'float32' || t.type === 'float64';
    return isFloat(x) && isFloat(scale) && isFloat(bias);
  }

  protected epsilon: number;
}
1 change: 1 addition & 0 deletions src/wasm-build-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"_gemm_f32",
"_matmul_f32",
"_batch_normalization_f32",
"_instance_normalization_f32",
"_sum_f32",
"_softmax_f32"
]
Expand Down
4 changes: 2 additions & 2 deletions src/wasm-ops/batch-normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ void batch_normalization_f32(void *data) {
}

// Core operator implementation
void batch_normalization_f32_imp(float *X, float *Y, int batch_size,
int num_channels, int channel_size,
void batch_normalization_f32_imp(float *X, float *Y, int32_t batch_size,
int32_t num_channels, int32_t channel_size,
float *scale, float *bias, float *mean,
float *variance, float epsilon) {
for (size_t nc = 0; nc < batch_size * num_channels; ++nc) {
Expand Down
56 changes: 56 additions & 0 deletions src/wasm-ops/instance-normalization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "instance-normalization.h"
#include "common.h"
#include <math.h>

// Wasm interop method.
// `data` is the marshalled argument blob: dataIndex[0] holds the argument
// count, dataIndex[1..8] hold offsets of the parameters inside the blob
// (X, Y, batch_size, num_channels, channel_size, scale, bias, epsilon).
// Fix: the previous version bound dataIndex[0] to an unused local `argc`,
// which triggered an unused-variable warning; the count is not needed here.
void instance_normalization_f32(void *data) {
  uint32_t *dataIndex = static_cast<uint32_t *>(data);
  instance_normalization_f32_imp(
      PARAM_FLOAT_PTR(data, dataIndex[1]), PARAM_FLOAT_PTR(data, dataIndex[2]),
      PARAM_INT32(data, dataIndex[3]), PARAM_INT32(data, dataIndex[4]),
      PARAM_INT32(data, dataIndex[5]), PARAM_FLOAT_PTR(data, dataIndex[6]),
      PARAM_FLOAT_PTR(data, dataIndex[7]), PARAM_FLOAT(data, dataIndex[8]));
}

// Core operator implementation.
// X and Y point to batch_size * num_channels * channel_size contiguous
// floats. Each (batch, channel) slice of channel_size elements is normalized
// to zero mean / unit (biased) variance, then scaled and shifted with the
// per-channel scale/bias (each of length num_channels).
//
// BUG FIX: the outer loop previously iterated nc < batch_size * channel_size,
// but there are batch_size * num_channels slices. Whenever
// channel_size != num_channels this produced wrong results and read/wrote
// past the end of X and Y (cf. the TS reference implementation, which loops
// over N * C).
void instance_normalization_f32_imp(float *X, float *Y, int32_t batch_size,
                                    int32_t num_channels, int32_t channel_size,
                                    float *scale, float *bias, float epsilon) {
  // widen before multiplying to avoid signed/unsigned comparison issues
  const size_t num_slices =
      static_cast<size_t>(batch_size) * static_cast<size_t>(num_channels);
  const size_t slice_size = static_cast<size_t>(channel_size);
  const size_t channels = static_cast<size_t>(num_channels);

  for (size_t nc = 0; nc < num_slices; nc++) {
    const size_t offset = nc * slice_size;
    const size_t end = offset + slice_size;
    const size_t channel = nc % channels;

    // mean of this slice
    float sum = 0.f;
    for (size_t i = offset; i < end; ++i) {
      sum += X[i];
    }
    const float mean = sum / channel_size;

    // biased variance of this slice (d*d is exact, unlike pow(d, 2))
    float sq_sum = 0.f;
    for (size_t i = offset; i < end; ++i) {
      const float d = X[i] - mean;
      sq_sum += d * d;
    }
    const float variance = sq_sum / channel_size;

    // normalize, then apply the per-channel affine transform
    const float inv_std = 1.f / std::sqrt(variance + epsilon);
    for (size_t i = offset; i < end; ++i) {
      Y[i] = scale[channel] * ((X[i] - mean) * inv_std) + bias[channel];
    }
  }
}
12 changes: 12 additions & 0 deletions src/wasm-ops/instance-normalization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include <stdint.h>

extern "C" {
// Emscripten entry point: unpacks the marshalled argument blob produced by
// the JS-side WasmBinding.ccall and forwards to the implementation below.
void instance_normalization_f32(void *);
// Core implementation. Parameters, in order:
// (X, Y, batch_size, num_channels, channel_size, scale, bias, epsilon).
void instance_normalization_f32_imp(float *, float *, int32_t, int32_t, int32_t,
                                    float *, float *, float);
}
6 changes: 5 additions & 1 deletion test/unittest-whitelist.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
"test_globalmaxpool_precomputed",
"test_globalmaxpool",
"test_identity",
"test_instancenorm_epsilon",
"test_instancenorm_example",
"test_leakyrelu_default",
"test_leakyrelu_example",
"test_leakyrelu",
Expand Down Expand Up @@ -556,7 +558,9 @@
"test_globalaveragepool_precomputed",
"test_globalaveragepool",
"test_globalmaxpool_precomputed",
"test_globalmaxpool"
"test_globalmaxpool",
"test_instancenorm_epsilon",
"test_instancenorm_example"
],
"ops": [
// Check in op tests that have native Wasm implementations
Expand Down