@@ -20,30 +20,27 @@ import {ENV} from '../environment';
2020import { MatrixOrientation } from '../kernels/types/matmul' ;
2121import { Scalar , Tensor1D , Tensor2D } from '../tensor' ;
2222import * as util from '../util' ;
23-
2423import { operation } from './operation' ;
2524
2625export class Ops {
2726 /**
28- * Computes the dot product of two matrices, A * B. These must be matrices,
29- * use matrixTimesVector and vectorTimesMatrix, dotProduct, and outerProduct
30- * in other cases.
27+ * Computes the dot product of two matrices, A * B. These must be matrices.
28+ *
3129 * @param a First matrix in dot product operation.
3230 * @param b Second matrix in dot product operation.
33- * @param aOrientation The MatrixOrientation of A. If using TRANSPOSED, will
34- * compute A^T * B.
35- * @param bOrientation The MatrixOrientation of B. If using TRANSPOSED, will
36- * compute A * B^T.
31+ * @param transposeA If true, `a` is transposed before multiplication.
32+ * @param transposeB If true, `b` is transposed before multiplication.
3733 */
3834 @doc ( { heading : 'Operations' , subheading : 'Matrices' } )
3935 @operation
4036 static matMul (
41- a : Tensor2D , b : Tensor2D , aOrientation = MatrixOrientation . REGULAR ,
42- bOrientation = MatrixOrientation . REGULAR ) : Tensor2D {
43- const innerShapeA =
44- ( aOrientation === MatrixOrientation . REGULAR ) ? a . shape [ 1 ] : a . shape [ 0 ] ;
45- const innerShapeB =
46- ( bOrientation === MatrixOrientation . REGULAR ) ? b . shape [ 0 ] : b . shape [ 1 ] ;
37+ a : Tensor2D , b : Tensor2D , transposeA = false , transposeB = false ) :
38+ Tensor2D {
39+ // For backward compatibility.
40+ [ transposeA , transposeB ] = [ enumToBool ( transposeA ) , enumToBool ( transposeB ) ] ;
41+
42+ const innerShapeA = transposeA ? a . shape [ 0 ] : a . shape [ 1 ] ;
43+ const innerShapeB = transposeB ? b . shape [ 1 ] : b . shape [ 0 ] ;
4744
4845 util . assert (
4946 a . rank === 2 && b . rank === 2 ,
@@ -54,30 +51,26 @@ export class Ops {
5451 innerShapeA === innerShapeB ,
5552 `Error in matMul: inner shapes (${ innerShapeA } ) and (` +
5653 `${ innerShapeB } ) of Tensors with shapes ${ a . shape } and ` +
57- `${ b . shape } and orientations ${ MatrixOrientation [ aOrientation ] } ` +
58- ` and ${ MatrixOrientation [ bOrientation ] } must match.` ) ;
54+ `${ b . shape } and transposeA= ${ transposeA } ` +
55+ ` and transposeB= ${ transposeB } must match.` ) ;
5956
6057 return ENV . engine . executeKernel (
61- 'MatMul' , { inputs : { a, b} , args : { aOrientation , bOrientation } } ,
58+ 'MatMul' , { inputs : { a, b} , args : { transposeA , transposeB } } ,
6259 ( dy : Tensor2D , y : Tensor2D ) => {
63- if ( aOrientation === MatrixOrientation . TRANSPOSED ||
64- bOrientation === MatrixOrientation . TRANSPOSED ) {
60+ if ( transposeA || transposeB ) {
6561 throw new Error (
6662 `Backprop for transposed MatMul not yet implemented.` ) ;
6763 }
6864 return {
69- a : ( ) => dy . matMul (
70- b . toFloat ( ) , MatrixOrientation . REGULAR ,
71- MatrixOrientation . TRANSPOSED ) as Tensor2D ,
72- b : ( ) => a . toFloat ( ) . matMul (
73- dy , MatrixOrientation . TRANSPOSED ,
74- MatrixOrientation . REGULAR ) as Tensor2D
65+ a : ( ) => dy . matMul ( b . toFloat ( ) , false , true ) as Tensor2D ,
66+ b : ( ) => a . toFloat ( ) . matMul ( dy , true , false ) as Tensor2D
7567 } ;
7668 } ) as Tensor2D ;
7769 }
7870
7971 /**
8072 * Computes the dot product of a vector and a matrix, v * B.
73+ *
8174 * @param v The vector in dot product operation.
8275 * @param matrix The matrix in dot product operation.
8376 */
@@ -126,6 +119,7 @@ export class Ops {
126119
127120 /**
128121 * Computes the dot product of two vectors, v1 * v2.
122+ *
129123 * @param v1 The first vector in the dot product operation.
130124 * @param v2 The second vector in the dot product operation.
131125 */
@@ -145,6 +139,7 @@ export class Ops {
145139
146140 /**
147141 * Computes the outer product of two vectors, v1 and v2.
142+ *
148143 * @param v1 The first vector in the outer product operation.
149144 * @param v2 The second vector in the dot product operation.
150145 */
@@ -159,3 +154,13 @@ export class Ops {
159154 return v1 . as2D ( - 1 , 1 ) . matMul ( v2 . as2D ( 1 , - 1 ) ) ;
160155 }
161156}
157+
158+ function enumToBool ( transpose : boolean | MatrixOrientation ) : boolean {
159+ if ( transpose === MatrixOrientation . REGULAR ) {
160+ return false ;
161+ }
162+ if ( transpose === MatrixOrientation . TRANSPOSED ) {
163+ return true ;
164+ }
165+ return transpose ;
166+ }
0 commit comments