This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Add Adam optimizer
Lewuathe committed Oct 4, 2017
1 parent 3477299 commit de94cee
Showing 3 changed files with 191 additions and 0 deletions.
130 changes: 130 additions & 0 deletions src/graph/optimizers/adam_optimizer.ts
@@ -0,0 +1,130 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {NDArrayMath} from '../../math/math';
import {NDArray, Scalar} from '../../math/ndarray';
import {Node} from '../graph';
import {SessionRuntime} from '../session';
import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map';

import {Optimizer} from './optimizer';

export class AdamOptimizer extends Optimizer {
  constructor(
      protected learningRate: number,
      private beta1: number, private beta2: number,
      specifiedVariableList?: Node[]) {
    super(learningRate, specifiedVariableList);
    this.eps = Scalar.new(1e-8);
    // b1, b2 keep the initial values of the beta1/beta2 hyperparameters.
    this.b1 = Scalar.new(this.beta1);
    this.b2 = Scalar.new(this.beta2);
    // accB1, accB2 are updated after every batch (running powers of beta1/beta2).
    this.accB1 = Scalar.new(this.beta1);
    this.accB2 = Scalar.new(this.beta2);
  }

  beforeBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    super.beforeBatch(
        math, batchSize, runtime, activationArrayMap, gradientArrayMap);

    if (this.firstMoment.size() === 0) {
      this.variableNodes.forEach(node => {
        this.firstMoment.set(node.output, NDArray.zeros(node.output.shape));
      });
    }

    if (this.secondMoment.size() === 0) {
      this.variableNodes.forEach(node => {
        this.secondMoment.set(node.output, NDArray.zeros(node.output.shape));
      });
    }
  }

  afterBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    math.scope((keep) => {
      this.variableNodes.forEach(node => {
        const oldVariable = activationArrayMap.get(node.output);
        const gradient = this.variableGradients.get(node.output);

        const oldFirstMoment = this.firstMoment.get(node.output);
        const oldSecondMoment = this.secondMoment.get(node.output);

        // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        const newFirstMoment = math.scaledArrayAdd(
            this.b1, oldFirstMoment, math.sub(this.one, this.b1), gradient);
        // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        const gradientSquare = math.multiply(gradient, gradient);
        const newSecondMoment = math.scaledArrayAdd(
            this.b2, oldSecondMoment, math.sub(this.one, this.b2),
            gradientSquare);

        // Bias-correct both moments by dividing by (1 - beta^t).
        const biasCorrectedFirstMoment = math.divide(
            newFirstMoment, math.sub(this.one, this.accB1));
        const biasCorrectedSecondMoment = math.divide(
            newSecondMoment, math.sub(this.one, this.accB2));

        const variable = math.scaledArrayAdd(
            this.c, math.divide(biasCorrectedFirstMoment,
                math.add(math.sqrt(biasCorrectedSecondMoment), this.eps)),
            this.one, oldVariable);
        activationArrayMap.set(node.output, keep(variable));
        node.data = variable;

        this.firstMoment.set(node.output, keep(newFirstMoment));
        this.secondMoment.set(node.output, keep(newSecondMoment));

        oldVariable.dispose();
        gradient.dispose();
        oldFirstMoment.dispose();
        oldSecondMoment.dispose();
      });
      // accB1, accB2 hold beta1 and beta2 raised to the power t
      // (the number of iterations so far).
      this.accB1 = keep(math.multiply(this.accB1, this.b1));
      this.accB2 = keep(math.multiply(this.accB2, this.b2));
    });

    this.variableGradients.dispose();
    this.variableGradients = new TensorArrayMap();
  }

  dispose() {
    super.dispose();
    this.firstMoment.dispose();
    this.secondMoment.dispose();
    this.eps.dispose();
    this.b1.dispose();
    this.b2.dispose();
    this.accB1.dispose();
    this.accB2.dispose();
  }

  // Exponentially decayed average of the gradient (first moment).
  private firstMoment = new TensorArrayMap();
  // Exponentially decayed average of the squared gradient (second moment).
  private secondMoment = new TensorArrayMap();
  private eps: Scalar;
  private b1: Scalar;
  private b2: Scalar;
  private accB1: Scalar;
  private accB2: Scalar;
}
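For reference, afterBatch follows the standard Adam update rule: b1/b2 correspond to $\beta_1$/$\beta_2$, accB1/accB2 to the running powers $\beta_1^t$/$\beta_2^t$, eps to $\epsilon$, and this.c is assumed (as in the other graph optimizers) to hold the negated learning rate $-\eta$:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)$$
$$w_t = w_{t-1} - \eta\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$$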
60 changes: 60 additions & 0 deletions src/graph/session_test.ts
@@ -27,6 +27,7 @@ import {MomentumOptimizer} from './optimizers/momentum_optimizer';
import {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
import {SGDOptimizer} from './optimizers/sgd_optimizer';
import {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
import {AdamOptimizer} from './optimizers/adam_optimizer';
import {FeedDictionary, FeedEntry, Session} from './session';


@@ -500,4 +501,63 @@ describe('Session', () => {
          dydw2, new Float32Array([-.4, -.8]), 2e-5);
    });
  });

  it('adam', () => {
    const x = g.placeholder('x', [2]);
    const w = g.variable('w', NDArray.zeros([1, 2]));
    const b = g.variable('b', NDArray.zeros([1]));
    const y = g.reduceSum(g.add(g.matmul(w, x), b));

    const safeMode = true;
    const optimizer = new AdamOptimizer(0.1, 0.8, 0.9);
    const math = new NDArrayMathCPU(safeMode);
    const session = new Session(g, math);
    const inputProvider: InputProvider = {
      getNextCopy() {
        return Array1D.new([2, 4]);
      },
      disposeCopy(math, example) {}
    };
    math.scope(() => {
      // y = reduce_sum(w_1*x_1 + w_2*x_2 + b), so grad_w = x = [2, 4].
      // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
      //                beta1*old_first_m_w2 + (1-beta1)*grad_w2]
      //             = [.4, .8]
      // new_second_m = [beta2*old_second_m_w1 + (1-beta2)*grad_w1**2,
      //                 beta2*old_second_m_w2 + (1-beta2)*grad_w2**2]
      //              = [.4, 1.6]
      // m = [new_first_m/(1-acc_beta1)] = [2, 4]
      // v = [new_second_m/(1-acc_beta2)] = [4, 16]
      // updates = [m_1/(sqrt(v_1) + eps),
      //            m_2/(sqrt(v_2) + eps)]
      //         = [1.0, 1.0]
      // w = [w1_old - lr*updates_1, w2_old - lr*updates_2]
      //   = [-0.1, -0.1]
      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
      const dydw = session.activationArrayMap.get(w).getValues();
      test_util.expectArraysClose(
          dydw, new Float32Array([-0.1, -0.1]), 1e-5);
      // Second batch: acc_beta1 = beta1^2 = 0.64, acc_beta2 = beta2^2 = 0.81.
      // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
      //                beta1*old_first_m_w2 + (1-beta1)*grad_w2]
      //             = [0.8*0.4 + 0.2*2, 0.8*0.8 + 0.2*4]
      //             = [0.72, 1.44]
      // new_second_m = [beta2*old_second_m_w1 + (1-beta2)*grad_w1**2,
      //                 beta2*old_second_m_w2 + (1-beta2)*grad_w2**2]
      //              = [0.9*0.4 + 0.1*4, 0.9*1.6 + 0.1*16]
      //              = [0.76, 3.04]
      // m = [new_first_m/(1-acc_beta1)] = [2, 4]
      // v = [new_second_m/(1-acc_beta2)] = [4, 16]
      // updates = [m_1/(sqrt(v_1) + eps),
      //            m_2/(sqrt(v_2) + eps)]
      //         = [1.0, 1.0]
      // w = [w1_old - lr*updates_1, w2_old - lr*updates_2]
      //   = [-0.2, -0.2]
      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
      const dydw2 = session.activationArrayMap.get(w).getValues();
      test_util.expectArraysClose(
          dydw2, new Float32Array([-.2, -.2]), 2e-5);
    });
  });
});
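The expected values in the comments above can be checked by hand. A minimal standalone sketch in plain TypeScript (no deeplearn.js dependency), using the fact that grad_w = x = [2, 4] on every step:

// Standalone check of the Adam arithmetic in the test comments above.
const lr = 0.1;
const beta1 = 0.8;
const beta2 = 0.9;
const eps = 1e-8;
const grad = [2, 4];   // dy/dw = x on every batch

let w = [0, 0];        // weights
let m = [0, 0];        // first moment
let v = [0, 0];        // second moment
let accB1 = beta1;     // beta1^t
let accB2 = beta2;     // beta2^t

for (let step = 1; step <= 2; step++) {
  m = m.map((mi, i) => beta1 * mi + (1 - beta1) * grad[i]);
  v = v.map((vi, i) => beta2 * vi + (1 - beta2) * grad[i] * grad[i]);
  const mHat = m.map(mi => mi / (1 - accB1));
  const vHat = v.map(vi => vi / (1 - accB2));
  w = w.map((wi, i) => wi - lr * mHat[i] / (Math.sqrt(vHat[i]) + eps));
  accB1 *= beta1;
  accB2 *= beta2;
  console.log(`step ${step}: w = [${w}]`);  // step 1: ~[-0.1, -0.1], step 2: ~[-0.2, -0.2]
}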
1 change: 1 addition & 0 deletions src/index.ts
@@ -35,6 +35,7 @@ export {AdadeltaOptimizer} from './graph/optimizers/adadelta_optimizer';
export {Optimizer} from './graph/optimizers/optimizer';
export {RMSPropOptimizer} from './graph/optimizers/rmsprop_optimizer';
export {SGDOptimizer} from './graph/optimizers/sgd_optimizer';
export {AdamOptimizer} from './graph/optimizers/adam_optimizer';
export {CostReduction, FeedEntry, Session} from './graph/session';
// tslint:disable-next-line:max-line-length
export {GraphRunner, GraphRunnerEventObserver, MetricReduction} from './graph_runner';
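With the export in place, the optimizer can be constructed from the library's public entry point. A minimal sketch mirroring the test above, assuming the package is consumed under the name 'deeplearn' and that Graph, Session, NDArray, and NDArrayMathCPU are exported alongside AdamOptimizer:

import {AdamOptimizer, Graph, NDArray, NDArrayMathCPU, Session} from 'deeplearn';

const g = new Graph();
const x = g.placeholder('x', [2]);
const w = g.variable('w', NDArray.zeros([1, 2]));
const b = g.variable('b', NDArray.zeros([1]));
const y = g.reduceSum(g.add(g.matmul(w, x), b));

const math = new NDArrayMathCPU();
const session = new Session(g, math);
// learningRate = 0.1, beta1 = 0.8, beta2 = 0.9, as in the test above.
const optimizer = new AdamOptimizer(0.1, 0.8, 0.9);
// Feed x via an InputProvider and train as in session_test.ts:
// session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);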