This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 0ae3c36

Add Adadelta optimizer (#150)

* Add Adadelta optimizer
* Post review update

Authored by Lewuathe, committed by Nikhil Thorat
1 parent 109574d

3 files changed: 162 additions, 0 deletions
src/graph/optimizers/adadelta_optimizer.ts (new file)

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {NDArrayMath} from '../../math/math';
import {NDArray, Scalar} from '../../math/ndarray';
import {Node} from '../graph';
import {SessionRuntime} from '../session';
import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map';

import {Optimizer} from './optimizer';

export class AdadeltaOptimizer extends Optimizer {
  constructor(
      protected learningRate: number, private gamma: number,
      specifiedVariableList?: Node[]) {
    super(learningRate, specifiedVariableList);
    this.eps = Scalar.new(1e-6);
    this.g = Scalar.new(this.gamma);
  }

  beforeBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    super.beforeBatch(
        math, batchSize, runtime, activationArrayMap, gradientArrayMap);
    if (this.accumulatedSquaredGradients.size() === 0) {
      this.variableNodes.forEach(node => {
        this.accumulatedSquaredGradients.set(
            node.output, NDArray.zeros(node.output.shape));
        this.accumulatedUpdates.set(
            node.output, NDArray.zeros(node.output.shape));
      });
    }
  }

  afterBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    math.scope((keep) => {
      this.variableNodes.forEach(node => {
        const oldVariable = activationArrayMap.get(node.output);
        const gradient = this.variableGradients.get(node.output);
        const oldCache = this.accumulatedSquaredGradients.get(node.output);
        const oldUpdates = this.accumulatedUpdates.get(node.output);

        const gradientSquare = math.multiply(gradient, gradient);
        // Exponential decay of average squared gradients.
        const cache = math.scaledArrayAdd(
            this.g, oldCache, math.sub(this.one, this.g), gradientSquare);

        const updates = math.multiply(
            math.divide(
                math.sqrt(math.add(oldUpdates, this.eps)),
                math.sqrt(math.add(oldCache, this.eps))),
            gradient);

        const variable =
            math.scaledArrayAdd(this.c, updates, this.one, oldVariable);

        const updateSquare = math.multiply(updates, updates);
        // Exponential decay of average updated values.
        const newUpdates = math.scaledArrayAdd(
            this.g, oldUpdates, math.sub(this.one, this.g), updateSquare);

        this.accumulatedSquaredGradients.set(node.output, keep(cache));
        this.accumulatedUpdates.set(node.output, keep(newUpdates));
        activationArrayMap.set(node.output, keep(variable));
        node.data = variable;

        oldVariable.dispose();
        oldCache.dispose();
        oldUpdates.dispose();
      });
    });

    this.variableGradients.dispose();
    this.variableGradients = new TensorArrayMap();
  }

  dispose() {
    super.dispose();
    this.eps.dispose();
    this.g.dispose();
    this.accumulatedSquaredGradients.dispose();
    this.accumulatedUpdates.dispose();
  }

  private accumulatedSquaredGradients = new TensorArrayMap();
  private accumulatedUpdates = new TensorArrayMap();
  private eps: Scalar;
  private g: Scalar;
}
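For reference, the update computed by afterBatch above can be written out as follows. This is a sketch of what the code does, with gamma written as \gamma, learningRate as \eta, and \epsilon = 1e-6; it assumes this.c and this.one from the base Optimizer class hold -learningRate and 1, which is what the expected values in the test below imply. Note that the denominator uses the squared-gradient average from the previous step, E[g^2]_{t-1}, rather than the freshly decayed E[g^2]_t:

E[g^2]_t = \gamma \, E[g^2]_{t-1} + (1 - \gamma) \, g_t^2

\Delta x_t = \frac{\sqrt{E[\Delta x^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_{t-1} + \epsilon}} \, g_t

x_t = x_{t-1} - \eta \, \Delta x_t

E[\Delta x^2]_t = \gamma \, E[\Delta x^2]_{t-1} + (1 - \gamma) \, \Delta x_t^2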

src/graph/session_test.ts

Lines changed: 53 additions & 0 deletions
@@ -26,6 +26,7 @@ import {AdagradOptimizer} from './optimizers/adagrad_optimizer';
 import {MomentumOptimizer} from './optimizers/momentum_optimizer';
 import {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
 import {SGDOptimizer} from './optimizers/sgd_optimizer';
+import {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
 import {FeedDictionary, FeedEntry, Session} from './session';


@@ -447,4 +448,56 @@ describe('Session', () => {
         session.train(w, [{tensor: x, data: inputProvider}], 1, optimizer))
         .toThrowError();
   });
+
+  it('adadelta', () => {
+    const x = g.placeholder('x', [2]);
+    const w = g.variable('w', NDArray.zeros([1, 2]));
+    const b = g.variable('b', NDArray.zeros([1]));
+    const y = g.reduceSum(g.add(g.matmul(w, x), b));
+
+    const safeMode = true;
+    const optimizer = new AdadeltaOptimizer(0.1, 0.8);
+    const math = new NDArrayMathCPU(safeMode);
+    const session = new Session(g, math);
+    const inputProvider: InputProvider = {
+      getNextCopy() {
+        return Array1D.new([2, 4]);
+      },
+      disposeCopy(math, example) {}
+    };
+
+    math.scope(() => {
+      // w = reduce_sum(w_1*x_1 + w_2*x_2 + b)
+      // cache = [gamma*old_cache_w1 + (1-gamma)*grad_w1**2,
+      //          gamma*old_cache_w2 + (1-gamma)*grad_w2**2]
+      //       = [.8, 3.2]
+      // updates = [sqrt(old_updates_w1 + eps)/sqrt(old_cache_w1 + eps)*grad_w1,
+      //            sqrt(old_updates_w2 + eps)/sqrt(old_cache_w2 + eps)*grad_w2]
+      //         = [2, 4]
+      // w = [ w1_old - lr*updates_w1,
+      //       w2_old - lr*updates_w2]
+      //   = [-0.2, -0.4]
+      // new_updates = [gamma * old_updates_w1 + (1 - gamma) * 2**2,
+      //                gamma * old_updates_w2 + (1 - gamma) * 4**2]
+      //             = [0.8, 3.2]
+      //
+      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
+      const dydw = session.activationArrayMap.get(w).getValues();
+      test_util.expectArraysClose(
+          dydw, new Float32Array([-0.2, -0.4]), 1e-5);
+      // cache = [gamma*old_cache_w1 + (1-gamma)*grad_w1**2,
+      //          gamma*old_cache_w2 + (1-gamma)*grad_w2**2]
+      //       = [1.44, 5.76]
+      // updates = [sqrt(old_updates_w1 + eps)/sqrt(old_cache_w1 + eps)*grad_w1,
+      //            sqrt(old_updates_w2 + eps)/sqrt(old_cache_w2 + eps)*grad_w2]
+      //         = [2, 4]
+      // w = [ w1_old - lr*updates_w1,
+      //       w2_old - lr*updates_w2]
+      //   = [-0.4, -0.8]
+      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
+      const dydw2 = session.activationArrayMap.get(w).getValues();
+      test_util.expectArraysClose(
+          dydw2, new Float32Array([-.4, -.8]), 2e-5);
+    });
+  });
 });

src/index.ts

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ export {ENV, Features} from './environment';
 export {Graph, Tensor} from './graph/graph';
 export {AdagradOptimizer} from './graph/optimizers/adagrad_optimizer';
 export {MomentumOptimizer} from './graph/optimizers/momentum_optimizer';
+export {AdadeltaOptimizer} from './graph/optimizers/adadelta_optimizer';
 export {Optimizer} from './graph/optimizers/optimizer';
 export {RMSPropOptimizer} from './graph/optimizers/rmsprop_optimizer';
 export {SGDOptimizer} from './graph/optimizers/sgd_optimizer';
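A minimal usage sketch of the new optimizer, mirroring the test added in this commit. The import paths are illustrative assumptions about where these symbols live in the source tree; everything else follows the test above.

// NOTE: import paths are illustrative; adjust to how the library is consumed.
import {Graph} from './graph/graph';
import {Session} from './graph/session';
import {AdadeltaOptimizer} from './graph/optimizers/adadelta_optimizer';
import {NDArrayMathCPU} from './math/math_cpu';
import {Array1D, NDArray} from './math/ndarray';
import {InputProvider} from './data/input_provider';

const math = new NDArrayMathCPU(true /* safeMode */);
const g = new Graph();

// y = reduce_sum(w * x + b), the same small model the test trains.
const x = g.placeholder('x', [2]);
const w = g.variable('w', NDArray.zeros([1, 2]));
const b = g.variable('b', NDArray.zeros([1]));
const y = g.reduceSum(g.add(g.matmul(w, x), b));

// learningRate = 0.1, gamma = 0.8 (decay rate of the running averages).
const optimizer = new AdadeltaOptimizer(0.1, 0.8);
const session = new Session(g, math);

const inputProvider: InputProvider = {
  getNextCopy() {
    return Array1D.new([2, 4]);
  },
  disposeCopy(math, example) {}
};

// Each call runs one batch; the optimizer updates w and b with the Adadelta
// rule and keeps its squared-gradient and squared-update accumulators across calls.
session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);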