@@ -32,10 +32,7 @@ var shape2strides = require( '@stdlib/ndarray/base/shape2strides' );
3232var isnanf = require ( '@stdlib/math/base/assert/is-nanf' ) ;
3333var format = require ( '@stdlib/string/format' ) ;
3434var tryRequire = require ( '@stdlib/utils/try-require' ) ;
35-
36- // var sgemm = require( '@stdlib/blas/base/sgemm' ).ndarray;
37- var sgemm = require ( '@stdlib/utils/noop' ) ; // FIXME: remove once `sgemm` merged
38-
35+ var dgemm = require ( '@stdlib/blas/base/dgemm' ) . ndarray ;
3936var pkg = require ( './../package.json' ) . name ;
4037
4138
@@ -46,7 +43,10 @@ var tfnode = tryRequire( resolve( __dirname, '..', 'node_modules', '@tensorflow/
4643var opts = {
4744 'skip' : ( tf instanceof Error )
4845} ;
49- var OPTS = {
46+ var SOPTS = {
47+ 'dtype' : 'float64'
48+ } ;
49+ var TOPTS = {
5050 'dtype' : 'float32'
5151} ;
5252
@@ -73,9 +73,9 @@ function createBenchmark1( shapeA, orderA, shapeB, orderB, shapeC, orderC ) {
7373 var B ;
7474 var C ;
7575
76- A = discreteUniform ( numel ( shapeA ) , 0 , 10 , OPTS ) ;
77- B = discreteUniform ( numel ( shapeB ) , 0 , 10 , OPTS ) ;
78- C = discreteUniform ( numel ( shapeC ) , 0 , 10 , OPTS ) ;
76+ A = discreteUniform ( numel ( shapeA ) , 0 , 10 , SOPTS ) ;
77+ B = discreteUniform ( numel ( shapeB ) , 0 , 10 , SOPTS ) ;
78+ C = discreteUniform ( numel ( shapeC ) , 0 , 10 , SOPTS ) ;
7979
8080 sa = shape2strides ( shapeA , orderA ) ;
8181 sb = shape2strides ( shapeB , orderB ) ;
@@ -94,7 +94,7 @@ function createBenchmark1( shapeA, orderA, shapeB, orderB, shapeC, orderC ) {
9494
9595 b . tic ( ) ;
9696 for ( i = 0 ; i < b . iterations ; i ++ ) {
97- sgemm ( 'no-transpose' , 'no-transpose' , shapeA [ 0 ] , shapeC [ 1 ] , shapeB [ 0 ] , 0.5 , A , sa [ 0 ] , sa [ 1 ] , 0 , B , sb [ 0 ] , sb [ 1 ] , 0 , 2.0 , C , sc [ 0 ] , sc [ 1 ] , 0 ) ;
97+ dgemm ( 'no-transpose' , 'no-transpose' , shapeA [ 0 ] , shapeC [ 1 ] , shapeB [ 0 ] , 0.5 , A , sa [ 0 ] , sa [ 1 ] , 0 , B , sb [ 0 ] , sb [ 1 ] , 0 , 2.0 , C , sc [ 0 ] , sc [ 1 ] , 0 ) ;
9898 if ( isnanf ( C [ i % C . length ] ) ) {
9999 b . fail ( 'should not return NaN' ) ;
100100 }
@@ -122,9 +122,9 @@ function createBenchmark2( shapeA, shapeB, shapeC ) {
122122 var bbuf ;
123123 var cbuf ;
124124
125- abuf = discreteUniform ( numel ( shapeA ) , 0 , 10 , OPTS ) ;
126- bbuf = discreteUniform ( numel ( shapeB ) , 0 , 10 , OPTS ) ;
127- cbuf = discreteUniform ( numel ( shapeC ) , 0 , 10 , OPTS ) ;
125+ abuf = discreteUniform ( numel ( shapeA ) , 0 , 10 , TOPTS ) ;
126+ bbuf = discreteUniform ( numel ( shapeB ) , 0 , 10 , TOPTS ) ;
127+ cbuf = discreteUniform ( numel ( shapeC ) , 0 , 10 , TOPTS ) ;
128128
129129 return benchmark ;
130130
@@ -144,9 +144,9 @@ function createBenchmark2( shapeA, shapeB, shapeC ) {
144144
145145 tf . setBackend ( 'cpu' ) ;
146146
147- A = tf . tensor ( abuf , shapeA , OPTS . dtype ) ;
148- B = tf . tensor ( bbuf , shapeB , OPTS . dtype ) ;
149- C = tf . tensor ( cbuf , shapeC , OPTS . dtype ) ;
147+ A = tf . tensor ( abuf , shapeA , TOPTS . dtype ) ;
148+ B = tf . tensor ( bbuf , shapeB , TOPTS . dtype ) ;
149+ C = tf . tensor ( cbuf , shapeC , TOPTS . dtype ) ;
150150
151151 b . tic ( ) ;
152152 for ( i = 0 ; i < b . iterations ; i ++ ) {
@@ -184,9 +184,9 @@ function createBenchmark3( shapeA, shapeB, shapeC ) {
184184 var bbuf ;
185185 var cbuf ;
186186
187- abuf = discreteUniform ( numel ( shapeA ) , 0 , 10 , OPTS ) ;
188- bbuf = discreteUniform ( numel ( shapeB ) , 0 , 10 , OPTS ) ;
189- cbuf = discreteUniform ( numel ( shapeC ) , 0 , 10 , OPTS ) ;
187+ abuf = discreteUniform ( numel ( shapeA ) , 0 , 10 , TOPTS ) ;
188+ bbuf = discreteUniform ( numel ( shapeB ) , 0 , 10 , TOPTS ) ;
189+ cbuf = discreteUniform ( numel ( shapeC ) , 0 , 10 , TOPTS ) ;
190190
191191 return benchmark ;
192192
@@ -206,9 +206,9 @@ function createBenchmark3( shapeA, shapeB, shapeC ) {
206206
207207 tfnode . setBackend ( 'tensorflow' ) ;
208208
209- A = tfnode . tensor ( abuf , shapeA , OPTS . dtype ) ;
210- B = tfnode . tensor ( bbuf , shapeB , OPTS . dtype ) ;
211- C = tfnode . tensor ( cbuf , shapeC , OPTS . dtype ) ;
209+ A = tfnode . tensor ( abuf , shapeA , TOPTS . dtype ) ;
210+ B = tfnode . tensor ( bbuf , shapeB , TOPTS . dtype ) ;
211+ C = tfnode . tensor ( cbuf , shapeC , TOPTS . dtype ) ;
212212
213213 b . tic ( ) ;
214214 for ( i = 0 ; i < b . iterations ; i ++ ) {
@@ -250,7 +250,7 @@ function main() {
250250 var i ;
251251
252252 min = 1 ; // 10^min
253- max = 6 ; // 10^max
253+ max = 5 ; // 10^max
254254
255255 for ( i = min ; i <= max ; i ++ ) {
256256 N = floor ( pow ( pow ( 10 , i ) , 1.0 / 2.0 ) ) ;
@@ -265,27 +265,27 @@ function main() {
265265 'row-major'
266266 ] ;
267267 f = createBenchmark1 ( shapes [ 0 ] , orders [ 0 ] , shapes [ 1 ] , orders [ 1 ] , shapes [ 2 ] , orders [ 2 ] ) ;
268- bench ( format ( '%s::stdlib:blas/base/sgemm :dtype=%s,orders=(%s),size=%d,shapes={(%s),(%s),(%s)}' , pkg , OPTS . dtype , orders . join ( ',' ) , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , f ) ;
268+ bench ( format ( '%s::stdlib:blas/base/dgemm :dtype=%s,orders=(%s),size=%d,shapes={(%s),(%s),(%s)}' , pkg , SOPTS . dtype , orders . join ( ',' ) , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , f ) ;
269269
270270 f = createBenchmark2 ( shapes [ 0 ] , shapes [ 1 ] , shapes [ 2 ] ) ;
271- bench ( format ( '%s::tfjs:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , OPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
271+ bench ( format ( '%s::tfjs:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , TOPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
272272
273273 f = createBenchmark3 ( shapes [ 0 ] , shapes [ 1 ] , shapes [ 2 ] ) ;
274- bench ( format ( '%s::tfjs-node:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , OPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
274+ bench ( format ( '%s::tfjs-node:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , TOPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
275275
276276 orders = [
277277 'row-major' ,
278278 'column-major' ,
279279 'row-major'
280280 ] ;
281281 f = createBenchmark1 ( shapes [ 0 ] , orders [ 0 ] , shapes [ 1 ] , orders [ 1 ] , shapes [ 2 ] , orders [ 2 ] ) ;
282- bench ( format ( '%s::stdlib:blas/base/sgemm :dtype=%s,orders=(%s),size=%d,shapes={(%s),(%s),(%s)}' , pkg , OPTS . dtype , orders . join ( ',' ) , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , f ) ;
282+ bench ( format ( '%s::stdlib:blas/base/dgemm :dtype=%s,orders=(%s),size=%d,shapes={(%s),(%s),(%s)}' , pkg , SOPTS . dtype , orders . join ( ',' ) , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , f ) ;
283283
284284 f = createBenchmark2 ( shapes [ 0 ] , shapes [ 1 ] , shapes [ 2 ] ) ;
285- bench ( format ( '%s::tfjs:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , OPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
285+ bench ( format ( '%s::tfjs:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , TOPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
286286
287287 f = createBenchmark3 ( shapes [ 0 ] , shapes [ 1 ] , shapes [ 2 ] ) ;
288- bench ( format ( '%s::tfjs-node:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , OPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
288+ bench ( format ( '%s::tfjs-node:matmul:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}' , pkg , TOPTS . dtype , numel ( shapes [ 2 ] ) , shapes [ 0 ] . join ( ',' ) , shapes [ 1 ] . join ( ',' ) , shapes [ 2 ] . join ( ',' ) ) , opts , f ) ;
289289 }
290290}
291291
0 commit comments