Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit dd90680

Browse files
authored
Fix truncated normal and expose test_util to public API (#172)
* fix truncated normal and expose test_util to public API * update ver
1 parent f4a39f4 commit dd90680

File tree

4 files changed

+34
-11
lines changed

4 files changed

+34
-11
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "deeplearn",
3-
"version": "0.2.7",
3+
"version": "0.2.8",
44
"description": "Hardware-accelerated JavaScript library for machine intelligence",
55
"private": false,
66
"main": "dist/src/index.js",

src/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import * as conv_util from './math/conv_util';
2020
import * as gpgpu_util from './math/webgl/gpgpu_util';
2121
import * as render_ndarray_gpu_util from './math/webgl/render_ndarray_gpu_util';
2222
import * as webgl_util from './math/webgl/webgl_util';
23+
import * as test_util from './test_util';
2324
import * as util from './util';
2425

2526
export {CheckpointLoader} from './data/checkpoint_loader';
@@ -29,9 +30,9 @@ export {InCPUMemoryShuffledInputProviderBuilder, InGPUMemoryShuffledInputProvide
2930
export {XhrDataset, XhrDatasetConfig, XhrModelConfig} from './data/xhr-dataset';
3031
export {ENV, Features} from './environment';
3132
export {Graph, Tensor} from './graph/graph';
33+
export {AdadeltaOptimizer} from './graph/optimizers/adadelta_optimizer';
3234
export {AdagradOptimizer} from './graph/optimizers/adagrad_optimizer';
3335
export {MomentumOptimizer} from './graph/optimizers/momentum_optimizer';
34-
export {AdadeltaOptimizer} from './graph/optimizers/adadelta_optimizer';
3536
export {Optimizer} from './graph/optimizers/optimizer';
3637
export {RMSPropOptimizer} from './graph/optimizers/rmsprop_optimizer';
3738
export {SGDOptimizer} from './graph/optimizers/sgd_optimizer';
@@ -51,6 +52,7 @@ export {
5152
conv_util,
5253
gpgpu_util,
5354
render_ndarray_gpu_util,
55+
test_util,
5456
util,
5557
webgl_util,
5658
xhr_dataset

src/util.ts

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
* =============================================================================
1616
*/
1717

18-
export type Vector = number[] | Float64Array | Float32Array | Int32Array |
19-
Int8Array | Int16Array;
18+
export type Vector =
19+
number[]|Float64Array|Float32Array|Int32Array|Int8Array|Int16Array;
2020

2121
/** Shuffles the array using Fisher-Yates algorithm. */
2222
// tslint:disable-next-line:no-any
@@ -63,7 +63,7 @@ export function randGauss(mean = 0, stdDev = 1, truncated = false): number {
6363
} while (s > 1);
6464

6565
const result = Math.sqrt(-2 * Math.log(s) / s) * v1;
66-
if (truncated && result > 2) {
66+
if (truncated && Math.abs(result) > 2) {
6767
return randGauss(mean, stdDev, true);
6868
}
6969
return mean + stdDev * result;
@@ -106,12 +106,7 @@ export function flatten(arr: any[], ret?: number[]): number[] {
106106
}
107107

108108
export type ArrayData =
109-
Float32Array |
110-
number |
111-
number[] |
112-
number[][] |
113-
number[][][] |
114-
number[][][][];
109+
Float32Array|number|number[]|number[][]|number[][][]|number[][][][];
115110

116111
export function inferShape(arr: ArrayData): number[] {
117112
const shape: number[] = [];

src/util_test.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,29 @@ describe('util.inferFromImplicitShape', () => {
218218
expect(() => util.inferFromImplicitShape([2, 3, 4], 25)).toThrowError();
219219
});
220220
});
221+
222+
describe('util.randGauss', () => {
223+
it('standard normal', () => {
224+
const a = util.randGauss();
225+
expect(a != null);
226+
});
227+
228+
it('truncated standard normal', () => {
229+
const numSamples = 1000;
230+
for (let i = 0; i < numSamples; ++i) {
231+
const sample = util.randGauss(0, 1, true);
232+
expect(Math.abs(sample) <= 2);
233+
}
234+
});
235+
236+
it('truncated normal, mu = 3, std=4', () => {
237+
const numSamples = 1000;
238+
const mean = 3;
239+
const stdDev = 4;
240+
for (let i = 0; i < numSamples; ++i) {
241+
const sample = util.randGauss(mean, stdDev, true);
242+
const normalizedSample = (sample - mean) / stdDev;
243+
expect(Math.abs(normalizedSample) <= 2);
244+
}
245+
});
246+
});

0 commit comments

Comments
 (0)