diff --git a/tfjs-layers/src/activations.ts b/tfjs-layers/src/activations.ts index 97c2c02f668..2edaa07f1be 100644 --- a/tfjs-layers/src/activations.ts +++ b/tfjs-layers/src/activations.ts @@ -227,6 +227,24 @@ export class Swish extends Activation { } serialization.registerClass(Swish); +/** + * Mish activation function: x * tanh(softplus(x)). + */ +export class Mish extends Activation { + /** @nocollapse */ + static readonly className = 'mish'; + /** + * Compute x * tanh(softplus(x)) for the input, inside a tidy() call. + * + * @param x Input tensor. + * @returns a Tensor of the same shape as x. + */ + apply(x: Tensor): Tensor { + return tidy(() => tfc.mul(x, tfc.tanh(tfc.softplus(x)))); + } +} +serialization.registerClass(Mish); + export function serializeActivation(activation: Activation): string { return activation.getClassName(); } diff --git a/tfjs-layers/src/activations_test.ts b/tfjs-layers/src/activations_test.ts index 69dd91c27eb..bc4d6289812 100644 --- a/tfjs-layers/src/activations_test.ts +++ b/tfjs-layers/src/activations_test.ts @@ -13,7 +13,7 @@ */ import {scalar, tensor1d, tensor2d, tensor3d} from '@tensorflow/tfjs-core'; -import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Tanh, Swish} from './activations'; +import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Tanh, Swish, Mish} from './activations'; import {describeMathCPUAndGPU, expectNoLeakedTensors, expectTensorsClose} from './utils/test_utils'; describeMathCPUAndGPU('linear activation', () => { @@ -333,3 +333,36 @@ describeMathCPUAndGPU('swish activation', () => { expectNoLeakedTensors(() => swish(initX), 1); }); }); + +describeMathCPUAndGPU('mish activation', () => { + const mish = new Mish().apply; + // Setup: Array with initial values. + // Execute: Mish applied to every element. + // Expect: Output array matches size and approximate expected values. 
+ it('1D', () => { + const initX = tensor1d([0, 1, 3, 9]); + const expectedVals = tensor1d([0., .865, 2.987, 9.]); + expectTensorsClose(mish(initX), expectedVals); + }); + it('1D all equal', () => { + const initX = tensor1d([-1, -1, -1, -1]); + const expectedVals = tensor1d([-0.303, -0.303, -0.303, -0.303]); + expectTensorsClose(mish(initX), expectedVals); + }); + it('2D', () => { + const initX = tensor2d([[0, 1, 3, 9], [0, 1, 3, 9]]); + const expectedVals = tensor2d( + [[0., .865, 2.987, 9.], [0., .865, 2.987, 9.]]); + expectTensorsClose(mish(initX), expectedVals); + }); + it('3D', () => { + const initX = tensor3d([[[0, 1, 3, 9], [0, 1, 3, 9]]]); + const expectedVals = tensor3d( + [[[0., .865, 2.987, 9.], [0., .865, 2.987, 9.]]]); + expectTensorsClose(mish(initX), expectedVals); + }); + it('Does not leak', () => { + const initX = tensor1d([0, 1, 3, 9]); + expectNoLeakedTensors(() => mish(initX), 1); + }); +});