Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit f0256e6

Browse files
authored
Add dl.concat/stack/expandDims/pad and tensor.buffer() (#708)
- Align `dl.concat` to match `tf.concat` (takes a list of tensors, instead of just two tensors) - Add `dl.stack` (and its chain sibling) - Add `dl.expandDims` (and its chain sibling) - Add `dl.pad` (and its chain sibling). Also have `dl.pad1d/2d` call into `dl.pad`. - Deprecate `tensor.val()/get()/locToIndex()/indexToVal()` - Add `tensor.buffer()` which returns `TensorBuffer` holding the underlying data. Fixes #696, #706 and #705
1 parent 3032c10 commit f0256e6

File tree

10 files changed

+474
-128
lines changed

10 files changed

+474
-128
lines changed

karma.conf.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ module.exports = function(config) {
3434
username: process.env.BROWSERSTACK_USERNAME,
3535
accessKey: process.env.BROWSERSTACK_KEY
3636
},
37+
reportSlowerThan: 500,
3738
browserNoActivityTimeout: 30000,
3839
customLaunchers: {
3940
bs_chrome_mac: {

src/math.ts

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import * as array_ops from './ops/array_ops';
2121
import * as batchnorm from './ops/batchnorm';
2222
import * as binary_ops from './ops/binary_ops';
2323
import * as compare from './ops/compare';
24-
import * as concat from './ops/concat';
2524
import * as conv from './ops/conv';
2625
import * as image_ops from './ops/image_ops';
2726
import * as logical from './ops/logical_ops';
@@ -38,7 +37,7 @@ import * as softmax_ops from './ops/softmax';
3837
import * as transpose from './ops/transpose';
3938
import * as unary_ops from './ops/unary_ops';
4039
import {ScopeResult} from './tape_util';
41-
import {Scalar, Tensor, Tensor1D} from './tensor';
40+
import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from './tensor';
4241
import {Tracking} from './tracking';
4342
import {Rank} from './types';
4443
import * as util from './util';
@@ -66,12 +65,6 @@ export class NDArrayMath {
6665
reverse3D = reverse.Ops.reverse3d;
6766
reverse4D = reverse.Ops.reverse4d;
6867

69-
concat = concat.Ops.concat;
70-
concat1D = concat.Ops.concat1d;
71-
concat2D = concat.Ops.concat2d;
72-
concat3D = concat.Ops.concat3d;
73-
concat4D = concat.Ops.concat4d;
74-
7568
batchNormalization = batchnorm.Ops.batchNormalization;
7669
batchNormalization2D = batchnorm.Ops.batchNormalization2d;
7770
batchNormalization3D = batchnorm.Ops.batchNormalization3d;
@@ -329,6 +322,31 @@ export class NDArrayMath {
329322
`got rank ${c.rank}.`);
330323
return this.multiply(c, a) as T;
331324
}
325+
326+
/** @deprecated */
327+
concat<T extends Tensor>(a: T, b: T, axis: number): T {
328+
return ops.concat([a, b], axis);
329+
}
330+
331+
/** @deprecated */
332+
concat1D(a: Tensor1D, b: Tensor1D): Tensor1D {
333+
return ops.concat1d([a, b]);
334+
}
335+
336+
/** @deprecated */
337+
concat2D(a: Tensor2D, b: Tensor2D, axis: number): Tensor2D {
338+
return ops.concat2d([a, b], axis);
339+
}
340+
341+
/** @deprecated */
342+
concat3D(a: Tensor3D, b: Tensor3D, axis: number): Tensor3D {
343+
return ops.concat3d([a, b], axis);
344+
}
345+
346+
/** @deprecated */
347+
concat4D(a: Tensor4D, b: Tensor4D, axis: number): Tensor4D {
348+
return ops.concat4d([a, b], axis);
349+
}
332350
}
333351

334352
export type ScopeFn<T extends ScopeResult> =

src/ops/array_ops.ts

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ import {ENV} from '../environment';
2020
// tslint:disable-next-line:max-line-length
2121
import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer} from '../tensor';
2222
// tslint:disable-next-line:max-line-length
23-
import {ArrayData, DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D} from '../types';
23+
import {ArrayData, DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TypedArray} from '../types';
2424
import * as util from '../util';
25-
25+
import {Concat} from './concat';
2626
import {operation} from './operation';
2727
import {MPRandGauss} from './rand';
2828

@@ -595,17 +595,16 @@ export class Ops {
595595
}
596596

597597
/**
598-
* Pads a Tensor1D with a given value.
598+
* Pads a `Tensor1D` with a given value.
599599
*
600600
* This operation will pad a tensor according to the `paddings` you specify.
601601
*
602602
* This operation currently only implements the `CONSTANT` mode from
603603
* Tensorflow's `pad` operation.
604604
*
605605
* @param x The tensor to pad.
606-
* @param paddings A tuple of ints [padLeft, padRight], how much to pad on the
607-
* left and right side of the tensor.
608-
* @param constantValue The scalar pad value to use. Defaults to 0.
606+
* @param paddings A tuple of ints `[padLeft, padRight]`, how much to pad.
607+
* @param constantValue The pad value to use. Defaults to 0.
609608
*/
610609
@operation
611610
static pad1d(x: Tensor1D, paddings: [number, number], constantValue = 0):
@@ -618,18 +617,15 @@ export class Ops {
618617
}
619618

620619
/**
621-
* Pads a Tensor2D with a given value.
622-
*
623-
* This operation will pad a tensor according to the `paddings` you specify.
620+
* Pads a `Tensor2D` with a given value and the `paddings` you specify.
624621
*
625622
* This operation currently only implements the `CONSTANT` mode from
626-
* Tensorflow's `pad` operation.
623+
* TensorFlow's `pad` operation.
627624
*
628625
* @param x The tensor to pad.
629-
* @param paddings A pair of tuple ints
630-
* [[padTop, padBottom], [padLeft, padRight]], how much to pad on the
631-
* tensor.
632-
* @param constantValue The scalar pad value to use. Defaults to 0.
626+
* @param paddings A pair of tuple ints:
627+
* `[[padTop, padBottom], [padLeft, padRight]]`, how much to pad.
628+
* @param constantValue The pad value to use. Defaults to 0.
633629
*/
634630
@operation
635631
static pad2d(
@@ -643,6 +639,83 @@ export class Ops {
643639
'Pad2D', {inputs: {x}, args: {paddings, constantValue}});
644640
}
645641

642+
/**
643+
* Pads a `Tensor` with a given value and the `paddings` you specify.
644+
*
645+
* This operation currently only implements the `CONSTANT` mode from
646+
* Tensorflow's `pad` operation.
647+
*
648+
* @param x The tensor to pad.
649+
* @param paddings An array of length `R` (the rank of the tensor), where each
650+
* element is a length-2 tuple of ints `[padBefore, padAfter]`, specifying
651+
* how much to pad along each dimension of the tensor.
652+
* @param constantValue The pad value to use. Defaults to 0.
653+
*/
654+
@doc({heading: 'Tensors', subheading: 'Transformations'})
655+
@operation
656+
static pad<T extends Tensor>(
657+
x: T, paddings: Array<[number, number]>, constantValue = 0): T {
658+
if (x.rank === 0) {
659+
throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
660+
} else if (x.rank === 1) {
661+
return Ops.pad1d(x as Tensor1D, paddings[0], constantValue) as T;
662+
} else if (x.rank === 2) {
663+
return Ops.pad2d(
664+
x as Tensor2D,
665+
paddings as [[number, number], [number, number]],
666+
constantValue) as T;
667+
} else {
668+
throw new Error(`pad of rank-${x.rank} tensor is not yet supported`);
669+
}
670+
}
671+
672+
/**
673+
* Stacks a list of rank-`R` `Tensor`s into one rank-`(R+1)` `Tensor`.
674+
*
675+
* @param tensors A list of tensor objects with the same shape and dtype.
676+
* @param axis The axis to stack along. Defaults to 0 (the first dim).
677+
*/
678+
@doc({heading: 'Tensors', subheading: 'Transformations'})
679+
@operation
680+
static stack<T extends Tensor>(tensors: T[], axis = 0): Tensor {
681+
util.assert(tensors.length >= 2, 'Pass at least two tensors to dl.stack');
682+
const rank = tensors[0].rank;
683+
const shape = tensors[0].shape;
684+
const dtype = tensors[0].dtype;
685+
686+
util.assert(axis <= rank, 'Axis must be <= rank of the tensor');
687+
688+
tensors.forEach(t => {
689+
util.assertShapesMatch(
690+
shape, t.shape,
691+
'All tensors passed to stack must have matching shapes');
692+
});
693+
694+
tensors.forEach(t => {
695+
util.assert(
696+
dtype === t.dtype,
697+
'All tensors passed to stack must have matching dtypes');
698+
});
699+
const expandedTensors = tensors.map(t => t.expandDims(axis));
700+
return Concat.concat(expandedTensors, axis);
701+
}
702+
703+
/**
704+
* Returns a `Tensor` that has expanded rank, by inserting a dimension
705+
* into the tensor's shape.
706+
*
707+
* @param axis The dimension index at which to insert shape of `1`. Defaults
708+
* to 0 (the first dimension).
709+
*/
710+
@doc({heading: 'Tensors', subheading: 'Transformations'})
711+
@operation
712+
static expandDims<R2 extends Rank>(x: Tensor, axis = 0): Tensor<R2> {
713+
util.assert(axis <= x.rank, 'Axis must be <= rank of the tensor');
714+
const newShape = x.shape.slice();
715+
newShape.splice(axis, 0, 1);
716+
return Ops.reshape(x, newShape);
717+
}
718+
646719
/**
647720
* Return an evenly spaced sequence of numbers over the given interval.
648721
*
@@ -720,18 +793,19 @@ export class Ops {
720793
/**
721794
* Creates an empty `TensorBuffer` with the specified `shape` and `dtype`.
722795
*
723-
* The values are stored in cpu as a `TypedArray`. Fill the buffer using
724-
* `buffer.set()`, or by modifying directly `buffer.values`.
796+
* The values are stored in cpu as `TypedArray`. Fill the buffer using
797+
* `buffer.set()`, or by modifying directly `buffer.values`. When done,
798+
* call `buffer.toTensor()` to get an immutable `Tensor` with those values.
725799
*
726-
* When done, call `buffer.toTensor()` to get an immutable `Tensor` with those
727-
* values.
728800
* @param shape An array of integers defining the output tensor shape.
729801
* @param dtype The dtype of the buffer. Defaults to 'float32'.
802+
* @param values The values of the buffer as `TypedArray`. Defaults to zeros.
730803
*/
731804
@doc({heading: 'Tensors', subheading: 'Creation'})
732805
static buffer<R extends Rank>(
733-
shape: ShapeMap[R], dtype: DataType = 'float32'): TensorBuffer<R> {
734-
return new TensorBuffer<R>(shape, dtype);
806+
shape: ShapeMap[R], dtype: DataType = 'float32', values?: TypedArray):
807+
TensorBuffer<R> {
808+
return new TensorBuffer<R>(shape, dtype, values);
735809
}
736810

737811
/**

src/ops/array_ops_test.ts

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
import * as dl from '../index';
1919
// tslint:disable-next-line:max-line-length
20-
import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';
21-
// tslint:disable-next-line:max-line-length
2220
import {ALL_ENVS, describeWithFlags, expectArraysClose, expectArraysEqual, expectValuesInRange} from '../test_util';
2321
import * as util from '../util';
22+
import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';
2423

2524
describeWithFlags('zeros', ALL_ENVS, () => {
2625
it('1D default dtype', () => {
@@ -1632,3 +1631,121 @@ describeWithFlags('fill', ALL_ENVS, () => {
16321631
expectArraysClose(a, [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]);
16331632
});
16341633
});
1634+
1635+
describeWithFlags('stack', ALL_ENVS, () => {
1636+
it('scalars 3, 5 and 7', () => {
1637+
const a = dl.scalar(3);
1638+
const b = dl.scalar(5);
1639+
const c = dl.scalar(7);
1640+
const res = dl.stack([a, b, c]);
1641+
expect(res.shape).toEqual([3]);
1642+
expectArraysClose(res, [3, 5, 7]);
1643+
});
1644+
1645+
it('scalars 3, 5 and 7 along axis=1 throws error', () => {
1646+
const a = dl.scalar(3);
1647+
const b = dl.scalar(5);
1648+
const c = dl.scalar(7);
1649+
const f = () => dl.stack([a, b, c], 1);
1650+
expect(f).toThrowError();
1651+
});
1652+
1653+
it('non matching shapes throws error', () => {
1654+
const a = dl.scalar(3);
1655+
const b = dl.tensor1d([5]);
1656+
const f = () => dl.stack([a, b]);
1657+
expect(f).toThrowError();
1658+
});
1659+
1660+
it('non matching dtypes throws error', () => {
1661+
const a = dl.scalar(3);
1662+
const b = dl.scalar(5, 'bool');
1663+
const f = () => dl.stack([a, b]);
1664+
expect(f).toThrowError();
1665+
});
1666+
1667+
it('2d but axis=3 throws error', () => {
1668+
const a = dl.zeros([2, 2]);
1669+
const b = dl.zeros([2, 2]);
1670+
const f = () => dl.stack([a, b], 3 /* axis */);
1671+
expect(f).toThrowError();
1672+
});
1673+
1674+
it('[1,2], [3,4] and [5,6], axis=0', () => {
1675+
const a = dl.tensor1d([1, 2]);
1676+
const b = dl.tensor1d([3, 4]);
1677+
const c = dl.tensor1d([5, 6]);
1678+
const res = dl.stack([a, b, c], 0 /* axis */);
1679+
expect(res.shape).toEqual([3, 2]);
1680+
expectArraysClose(res, [1, 2, 3, 4, 5, 6]);
1681+
});
1682+
1683+
it('[1,2], [3,4] and [5,6], axis=1', () => {
1684+
const a = dl.tensor1d([1, 2]);
1685+
const b = dl.tensor1d([3, 4]);
1686+
const c = dl.tensor1d([5, 6]);
1687+
const res = dl.stack([a, b, c], 1 /* axis */);
1688+
expect(res.shape).toEqual([2, 3]);
1689+
expectArraysClose(res, [1, 3, 5, 2, 4, 6]);
1690+
});
1691+
1692+
it('[[1,2],[3,4]] and [[5, 6], [7, 8]], axis=0', () => {
1693+
const a = dl.tensor2d([[1, 2], [3, 4]]);
1694+
const b = dl.tensor2d([[5, 6], [7, 8]]);
1695+
const res = dl.stack([a, b], 0 /* axis */);
1696+
expect(res.shape).toEqual([2, 2, 2]);
1697+
expectArraysClose(res, [1, 2, 3, 4, 5, 6, 7, 8]);
1698+
});
1699+
1700+
it('[[1,2],[3,4]] and [[5, 6], [7, 8]], axis=2', () => {
1701+
const a = dl.tensor2d([[1, 2], [3, 4]]);
1702+
const b = dl.tensor2d([[5, 6], [7, 8]]);
1703+
const c = dl.tensor2d([[9, 10], [11, 12]]);
1704+
const res = dl.stack([a, b, c], 2 /* axis */);
1705+
expect(res.shape).toEqual([2, 2, 3]);
1706+
expectArraysClose(res, [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12]);
1707+
});
1708+
});
1709+
1710+
describeWithFlags('expandDims', ALL_ENVS, () => {
1711+
it('scalar, default axis is 0', () => {
1712+
const res = dl.scalar(1).expandDims();
1713+
expect(res.shape).toEqual([1]);
1714+
expectArraysClose(res, [1]);
1715+
});
1716+
1717+
it('scalar, axis is out of bounds throws error', () => {
1718+
const f = () => dl.scalar(1).expandDims(1);
1719+
expect(f).toThrowError();
1720+
});
1721+
1722+
it('1d, axis=0', () => {
1723+
const res = dl.tensor1d([1, 2, 3]).expandDims(0 /* axis */);
1724+
expect(res.shape).toEqual([1, 3]);
1725+
expectArraysClose(res, [1, 2, 3]);
1726+
});
1727+
1728+
it('1d, axis=1', () => {
1729+
const res = dl.tensor1d([1, 2, 3]).expandDims(1 /* axis */);
1730+
expect(res.shape).toEqual([3, 1]);
1731+
expectArraysClose(res, [1, 2, 3]);
1732+
});
1733+
1734+
it('2d, axis=0', () => {
1735+
const res = dl.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(0 /* axis */);
1736+
expect(res.shape).toEqual([1, 3, 2]);
1737+
expectArraysClose(res, [1, 2, 3, 4, 5, 6]);
1738+
});
1739+
1740+
it('2d, axis=1', () => {
1741+
const res = dl.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(1 /* axis */);
1742+
expect(res.shape).toEqual([3, 1, 2]);
1743+
expectArraysClose(res, [1, 2, 3, 4, 5, 6]);
1744+
});
1745+
1746+
it('2d, axis=2', () => {
1747+
const res = dl.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(2 /* axis */);
1748+
expect(res.shape).toEqual([3, 2, 1]);
1749+
expectArraysClose(res, [1, 2, 3, 4, 5, 6]);
1750+
});
1751+
});

0 commit comments

Comments
 (0)