Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit c25b259

Browse files
authored
Align dl.matmul with tf.matmul (#711)
Use booleans `transposeA/B` instead of the enum `MatrixOrientation`. Support the enum for backwards compatibility. Fixes #697
1 parent f0256e6 commit c25b259

File tree

11 files changed

+100
-131
lines changed

11 files changed

+100
-131
lines changed

src/engine_test.ts

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
*/
1717

1818
import * as dl from './index';
19-
import {MatrixOrientation} from './kernels/types/matmul';
2019
import {Tensor} from './tensor';
2120
// tslint:disable-next-line:max-line-length
2221
import {ALL_ENVS, describeWithFlags, expectArraysClose, expectArraysEqual, expectNumbersClose} from './test_util';
@@ -217,19 +216,15 @@ describeWithFlags('gradients', ALL_ENVS, () => {
217216

218217
// de/da = dot(de/dy, bT)
219218
expect(da.shape).toEqual(a.shape);
220-
expectArraysClose(
221-
da,
222-
dl.matMul(
223-
dedm, b, MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED),
224-
1e-1);
219+
let transposeA = false;
220+
let transposeB = true;
221+
expectArraysClose(da, dl.matMul(dedm, b, transposeA, transposeB), 1e-1);
225222

226223
// de/db = dot(aT, de/dy)
227224
expect(db.shape).toEqual(b.shape);
228-
expectArraysClose(
229-
db,
230-
dl.matMul(
231-
a, dedm, MatrixOrientation.TRANSPOSED, MatrixOrientation.REGULAR),
232-
1e-1);
225+
transposeA = true;
226+
transposeB = false;
227+
expectArraysClose(db, dl.matMul(a, dedm, transposeA, transposeB), 1e-1);
233228
});
234229

235230
it('grad(f)', () => {
@@ -326,18 +321,14 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
326321

327322
const [da, db] = grads;
328323
// de/da = dot(de/dy, bT)
329-
expectArraysClose(
330-
da,
331-
dl.matMul(
332-
dedm, b, MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED),
333-
1e-1);
324+
let transposeA = false;
325+
let transposeB = true;
326+
expectArraysClose(da, dl.matMul(dedm, b, transposeA, transposeB), 1e-1);
334327

335328
// de/db = dot(aT, de/dy)
336-
expectArraysClose(
337-
db,
338-
dl.matMul(
339-
a, dedm, MatrixOrientation.TRANSPOSED, MatrixOrientation.REGULAR),
340-
1e-1);
329+
transposeA = true;
330+
transposeB = false;
331+
expectArraysClose(db, dl.matMul(a, dedm, transposeA, transposeB), 1e-1);
341332
});
342333

343334
it('matmul + relu + inner tidy', () => {
@@ -365,18 +356,14 @@ describeWithFlags('valueAndGradients', ALL_ENVS, () => {
365356

366357
const [da, db] = grads;
367358
// de/da = dot(de/dy, bT)
368-
expectArraysClose(
369-
da,
370-
dl.matMul(
371-
dedm, b, MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED),
372-
1e-1);
359+
let transposeA = false;
360+
let transposeB = true;
361+
expectArraysClose(da, dl.matMul(dedm, b, transposeA, transposeB), 1e-1);
373362

374363
// de/db = dot(aT, de/dy)
375-
expectArraysClose(
376-
db,
377-
dl.matMul(
378-
a, dedm, MatrixOrientation.TRANSPOSED, MatrixOrientation.REGULAR),
379-
1e-1);
364+
transposeA = true;
365+
transposeB = false;
366+
expectArraysClose(db, dl.matMul(a, dedm, transposeA, transposeB), 1e-1);
380367
});
381368
});
382369

src/graph/ops/matmul.ts

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
*/
1717

1818
import {keep, tidy} from '../../globals';
19-
import {MatrixOrientation} from '../../kernels/types/matmul';
2019
import {NDArrayMath} from '../../math';
2120
import {Tensor1D, Tensor2D} from '../../tensor';
2221
import {SymbolicTensor} from '../graph';
@@ -75,16 +74,12 @@ export class MatMul extends Operation {
7574
// dx1 = dy * x2T
7675
// dx2 = x1T * dy
7776
if (graph_util.shouldBackProp(this.x1Tensor)) {
78-
const dx1 = math.matMul(
79-
dy as Tensor2D, x2 as Tensor2D, MatrixOrientation.REGULAR,
80-
MatrixOrientation.TRANSPOSED);
77+
const dx1 = math.matMul(dy as Tensor2D, x2 as Tensor2D, false, true);
8178
gradientArrays.add(
8279
this.x1Tensor, this.x1Tensor.shape.length === 1 ? dx1.as1D() : dx1);
8380
}
8481
if (graph_util.shouldBackProp(this.x2Tensor)) {
85-
const dx2 = math.matMul(
86-
x1 as Tensor2D, dy as Tensor2D, MatrixOrientation.TRANSPOSED,
87-
MatrixOrientation.REGULAR);
82+
const dx2 = math.matMul(x1 as Tensor2D, dy as Tensor2D, true, false);
8883
gradientArrays.add(
8984
this.x2Tensor, this.x2Tensor.shape.length === 1 ? dx2.as1D() : dx2);
9085
}

src/kernels/backend.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ import {Conv2DInfo} from '../ops/conv_util';
2121
import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
2222
import {DataType, Rank, TypedArray} from '../types';
2323

24-
import {MatrixOrientation} from './types/matmul';
25-
2624
export interface TensorStorage {
2725
read(dataId: DataId): Promise<TypedArray>;
2826
readSync(dataId: DataId): TypedArray;
@@ -45,9 +43,8 @@ export interface BackendTimer { time(f: () => void): Promise<number>; }
4543
* methods).
4644
*/
4745
export interface KernelBackend extends TensorStorage, BackendTimer {
48-
matMul(
49-
a: Tensor2D, b: Tensor2D, aOrientation: MatrixOrientation,
50-
bOrientation: MatrixOrientation): Tensor2D;
46+
matMul(a: Tensor2D, b: Tensor2D, transposeA: boolean, transposeB: boolean):
47+
Tensor2D;
5148

5249
slice1D(x: Tensor1D, begin: number, size: number): Tensor1D;
5350
slice2D(x: Tensor2D, begin: [number, number], size: [number, number]):

src/kernels/backend_cpu.ts

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
*/
1717

1818
import * as seedrandom from 'seedrandom';
19-
2019
import {ENV} from '../environment';
2120
import {NDArrayMath} from '../math';
2221
import * as axis_util from '../ops/axis_util';
@@ -31,9 +30,7 @@ import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'
3130
import * as types from '../types';
3231
import {DataType, DataTypeMap, Rank, TypedArray} from '../types';
3332
import * as util from '../util';
34-
3533
import {KernelBackend} from './backend';
36-
import {MatrixOrientation} from './types/matmul';
3734

3835
export class MathBackendCPU implements KernelBackend {
3936
private data = new WeakMap<DataId, DataTypeMap[DataType]>();
@@ -269,28 +266,20 @@ export class MathBackendCPU implements KernelBackend {
269266
T;
270267
}
271268

272-
matMul(
273-
a: Tensor2D, b: Tensor2D, aOrientation = MatrixOrientation.REGULAR,
274-
bOrientation = MatrixOrientation.REGULAR): Tensor2D {
275-
const sharedDim =
276-
(aOrientation === MatrixOrientation.REGULAR) ? a.shape[1] : a.shape[0];
277-
278-
const leftDim =
279-
(aOrientation === MatrixOrientation.REGULAR) ? a.shape[0] : a.shape[1];
280-
const rightDim =
281-
(bOrientation === MatrixOrientation.REGULAR) ? b.shape[1] : b.shape[0];
269+
matMul(a: Tensor2D, b: Tensor2D, transposeA: boolean, transposeB: boolean):
270+
Tensor2D {
271+
const sharedDim = transposeA ? a.shape[0] : a.shape[1];
272+
const leftDim = transposeA ? a.shape[1] : a.shape[0];
273+
const rightDim = transposeB ? b.shape[0] : b.shape[1];
282274

283275
const normalGetter = (matrix: Tensor2D, i: number, j: number) =>
284276
matrix.get(i, j);
285277
const transposedGetter = (matrix: Tensor2D, i: number, j: number) =>
286278
matrix.get(j, i);
287279

288-
const aGetter = (aOrientation === MatrixOrientation.REGULAR) ?
289-
normalGetter :
290-
transposedGetter;
291-
const bGetter = (bOrientation === MatrixOrientation.REGULAR) ?
292-
normalGetter :
293-
transposedGetter;
280+
const aGetter = transposeA ? transposedGetter : normalGetter;
281+
const bGetter = transposeB ? transposedGetter : normalGetter;
282+
294283
const values = new Float32Array(leftDim * rightDim);
295284
let index = 0;
296285
for (let i = 0; i < leftDim; ++i) {

src/kernels/backend_webgl.ts

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ import * as types from '../types';
2626
// tslint:disable-next-line:max-line-length
2727
import {DataType, DataTypeMap, Rank, RecursiveArray, TypedArray} from '../types';
2828
import * as util from '../util';
29-
3029
import {KernelBackend} from './backend';
31-
import {MatrixOrientation} from './types/matmul';
3230
import {ArgMinMaxProgram} from './webgl/argminmax_gpu';
3331
import {AvgPool2DBackpropProgram} from './webgl/avg_pool_backprop_gpu';
3432
import {BatchNormProgram} from './webgl/batchnorm_gpu';
@@ -330,11 +328,9 @@ export class MathBackendWebGL implements KernelBackend {
330328
return this.compileAndRun(program, [x]) as T;
331329
}
332330

333-
matMul(
334-
a: Tensor2D, b: Tensor2D, aOrientation: MatrixOrientation,
335-
bOrientation: MatrixOrientation): Tensor2D {
336-
const program =
337-
new MatMulProgram(a.shape, b.shape, aOrientation, bOrientation);
331+
matMul(a: Tensor2D, b: Tensor2D, transposeA: boolean, transposeB: boolean):
332+
Tensor2D {
333+
const program = new MatMulProgram(a.shape, b.shape, transposeA, transposeB);
338334
return this.compileAndRun<Tensor2D, Tensor2D>(program, [a, b]);
339335
}
340336

src/kernels/kernel_registry.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ executeKernel<R extends Rank, K extends keyof KernelConfigRegistry<R>, O extends
5959
if (kernelName === 'MatMul') {
6060
const config = inputAndArgs as MatMulNode['inputAndArgs'];
6161
return backend.matMul(
62-
config.inputs.a, config.inputs.b, config.args.aOrientation,
63-
config.args.bOrientation) as O;
62+
config.inputs.a, config.inputs.b, config.args.transposeA,
63+
config.args.transposeB) as O;
6464
} else if (kernelName === 'Slice1D') {
6565
const config = inputAndArgs as Slice1DNode['inputAndArgs'];
6666
return backend.slice1D(

src/kernels/types/matmul.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import {Tensor2D} from '../../tensor';
2121
export interface MatMulNode extends KernelNode {
2222
inputAndArgs: {
2323
inputs: {a: Tensor2D; b: Tensor2D;};
24-
args: {aOrientation: MatrixOrientation; bOrientation: MatrixOrientation};
24+
args: {transposeA: boolean; transposeB: boolean};
2525
};
2626
output: Tensor2D;
2727
gradient: (dy: Tensor2D, y: Tensor2D) => {

src/kernels/webgl/mulmat_gpu.ts

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
* =============================================================================
1616
*/
1717

18-
import {MatrixOrientation} from '../types/matmul';
1918
import {GPGPUProgram} from './gpgpu_math';
2019

2120
export class MatMulProgram implements GPGPUProgram {
@@ -24,25 +23,19 @@ export class MatMulProgram implements GPGPUProgram {
2423
userCode: string;
2524

2625
constructor(
27-
aShape: [number, number], bShape: [number, number],
28-
aOrient = MatrixOrientation.REGULAR,
29-
bOrient = MatrixOrientation.REGULAR) {
30-
const outerShapeA =
31-
(aOrient === MatrixOrientation.REGULAR) ? aShape[0] : aShape[1];
32-
const outerShapeB =
33-
(bOrient === MatrixOrientation.REGULAR) ? bShape[1] : bShape[0];
26+
aShape: [number, number], bShape: [number, number], transposeA = false,
27+
transposeB = false) {
28+
const outerShapeA = transposeA ? aShape[1] : aShape[0];
29+
const outerShapeB = transposeB ? bShape[0] : bShape[1];
30+
const sharedDim = transposeA ? aShape[0] : aShape[1];
3431
this.outputShape = [outerShapeA, outerShapeB];
3532

36-
const sharedDim =
37-
(aOrient === MatrixOrientation.REGULAR ? aShape[1] : aShape[0]);
3833
const aSnippetFromOffset = (vec4Offset: number, indexVar: string|number) =>
39-
(aOrient === MatrixOrientation.REGULAR) ?
40-
`aRow, ${indexVar} + ${vec4Offset}` :
41-
`${indexVar} + ${vec4Offset}, aRow`;
34+
transposeA ? `${indexVar} + ${vec4Offset}, aRow` :
35+
`aRow, ${indexVar} + ${vec4Offset}`;
4236
const bSnippetFromOffset = (vec4Offset: number, indexVar: string|number) =>
43-
(bOrient === MatrixOrientation.REGULAR) ?
44-
`${indexVar} + ${vec4Offset}, bCol` :
45-
`bCol, ${indexVar} + ${vec4Offset}`;
37+
transposeB ? `bCol, ${indexVar} + ${vec4Offset}` :
38+
`${indexVar} + ${vec4Offset}, bCol`;
4639

4740
const sharedDimNearestVec4 = Math.floor(sharedDim / 4) * 4;
4841
const sharedDimVec4Remainder = sharedDim % 4;

src/ops/matmul.ts

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,27 @@ import {ENV} from '../environment';
2020
import {MatrixOrientation} from '../kernels/types/matmul';
2121
import {Scalar, Tensor1D, Tensor2D} from '../tensor';
2222
import * as util from '../util';
23-
2423
import {operation} from './operation';
2524

2625
export class Ops {
2726
/**
28-
* Computes the dot product of two matrices, A * B. These must be matrices,
29-
* use matrixTimesVector and vectorTimesMatrix, dotProduct, and outerProduct
30-
* in other cases.
27+
* Computes the dot product of two matrices, A * B. These must be matrices.
28+
*
3129
* @param a First matrix in dot product operation.
3230
* @param b Second matrix in dot product operation.
33-
* @param aOrientation The MatrixOrientation of A. If using TRANSPOSED, will
34-
* compute A^T * B.
35-
* @param bOrientation The MatrixOrientation of B. If using TRANSPOSED, will
36-
* compute A * B^T.
31+
* @param transposeA If true, `a` is transposed before multiplication.
32+
* @param transposeB If true, `b` is transposed before multiplication.
3733
*/
3834
@doc({heading: 'Operations', subheading: 'Matrices'})
3935
@operation
4036
static matMul(
41-
a: Tensor2D, b: Tensor2D, aOrientation = MatrixOrientation.REGULAR,
42-
bOrientation = MatrixOrientation.REGULAR): Tensor2D {
43-
const innerShapeA =
44-
(aOrientation === MatrixOrientation.REGULAR) ? a.shape[1] : a.shape[0];
45-
const innerShapeB =
46-
(bOrientation === MatrixOrientation.REGULAR) ? b.shape[0] : b.shape[1];
37+
a: Tensor2D, b: Tensor2D, transposeA = false, transposeB = false):
38+
Tensor2D {
39+
// For backward compatibility.
40+
[transposeA, transposeB] = [enumToBool(transposeA), enumToBool(transposeB)];
41+
42+
const innerShapeA = transposeA ? a.shape[0] : a.shape[1];
43+
const innerShapeB = transposeB ? b.shape[1] : b.shape[0];
4744

4845
util.assert(
4946
a.rank === 2 && b.rank === 2,
@@ -54,30 +51,26 @@ export class Ops {
5451
innerShapeA === innerShapeB,
5552
`Error in matMul: inner shapes (${innerShapeA}) and (` +
5653
`${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
57-
`${b.shape} and orientations ${MatrixOrientation[aOrientation]}` +
58-
` and ${MatrixOrientation[bOrientation]} must match.`);
54+
`${b.shape} and transposeA=${transposeA}` +
55+
` and transposeB=${transposeB} must match.`);
5956

6057
return ENV.engine.executeKernel(
61-
'MatMul', {inputs: {a, b}, args: {aOrientation, bOrientation}},
58+
'MatMul', {inputs: {a, b}, args: {transposeA, transposeB}},
6259
(dy: Tensor2D, y: Tensor2D) => {
63-
if (aOrientation === MatrixOrientation.TRANSPOSED ||
64-
bOrientation === MatrixOrientation.TRANSPOSED) {
60+
if (transposeA || transposeB) {
6561
throw new Error(
6662
`Backprop for transposed MatMul not yet implemented.`);
6763
}
6864
return {
69-
a: () => dy.matMul(
70-
b.toFloat(), MatrixOrientation.REGULAR,
71-
MatrixOrientation.TRANSPOSED) as Tensor2D,
72-
b: () => a.toFloat().matMul(
73-
dy, MatrixOrientation.TRANSPOSED,
74-
MatrixOrientation.REGULAR) as Tensor2D
65+
a: () => dy.matMul(b.toFloat(), false, true) as Tensor2D,
66+
b: () => a.toFloat().matMul(dy, true, false) as Tensor2D
7567
};
7668
}) as Tensor2D;
7769
}
7870

7971
/**
8072
* Computes the dot product of a vector and a matrix, v * B.
73+
*
8174
* @param v The vector in dot product operation.
8275
* @param matrix The matrix in dot product operation.
8376
*/
@@ -126,6 +119,7 @@ export class Ops {
126119

127120
/**
128121
* Computes the dot product of two vectors, v1 * v2.
122+
*
129123
* @param v1 The first vector in the dot product operation.
130124
* @param v2 The second vector in the dot product operation.
131125
*/
@@ -145,6 +139,7 @@ export class Ops {
145139

146140
/**
147141
* Computes the outer product of two vectors, v1 and v2.
142+
*
148143
* @param v1 The first vector in the outer product operation.
149144
* @param v2 The second vector in the dot product operation.
150145
*/
@@ -159,3 +154,13 @@ export class Ops {
159154
return v1.as2D(-1, 1).matMul(v2.as2D(1, -1));
160155
}
161156
}
157+
158+
/**
 * Normalizes a transpose argument to a plain boolean so that a legacy
 * `MatrixOrientation` enum value can still be supplied where a boolean
 * transpose flag is now expected.
 *
 * @param transpose Either a boolean transpose flag or a `MatrixOrientation`
 *     enum value.
 * @returns `false` for `MatrixOrientation.REGULAR`, `true` for
 *     `MatrixOrientation.TRANSPOSED`; any other value is returned unchanged.
 */
function enumToBool(transpose: boolean|MatrixOrientation): boolean {
159+
if (transpose === MatrixOrientation.REGULAR) {
160+
return false;
161+
}
162+
if (transpose === MatrixOrientation.TRANSPOSED) {
163+
return true;
164+
}
165+
// Neither enum member matched, so the value is passed through as-is.
return transpose;
166+
}

0 commit comments

Comments
 (0)