@@ -98,23 +98,23 @@ export class NDArrayMathGPU extends NDArrayMath {
9898 ] ) : Array2D {
9999 const program = new SliceProgram ( size ) ;
100100 const customSetup = program . getCustomSetupFunc ( begin ) ;
101- return this . compileAndRun ( program , [ input ] , null , customSetup ) ;
101+ return this . compileAndRun ( program , [ input ] , null , customSetup ) as Array2D ;
102102 }
103103
104104 protected slice3DInternal (
105105 input : Array3D , begin : [ number , number , number ] ,
106106 size : [ number , number , number ] ) : Array3D {
107107 const program = new SliceProgram ( size ) ;
108108 const customSetup = program . getCustomSetupFunc ( begin ) ;
109- return this . compileAndRun ( program , [ input ] , null , customSetup ) ;
109+ return this . compileAndRun ( program , [ input ] , null , customSetup ) as Array3D ;
110110 }
111111
112112 protected slice4DInternal (
113113 input : Array4D , begin : [ number , number , number , number ] ,
114114 size : [ number , number , number , number ] ) : Array4D {
115115 const program = new SliceProgram ( size ) ;
116116 const customSetup = program . getCustomSetupFunc ( begin ) ;
117- return this . compileAndRun ( program , [ input ] , null , customSetup ) ;
117+ return this . compileAndRun ( program , [ input ] , null , customSetup ) as Array4D ;
118118 }
119119
120120 protected copy2DInternal (
@@ -130,33 +130,33 @@ export class NDArrayMathGPU extends NDArrayMath {
130130
131131 protected concat1DInternal ( a : Array1D , b : Array1D ) : Array1D {
132132 const program = new ConcatProgram ( a . shape , b . shape , 0 ) ;
133- return this . compileAndRun ( program , [ a , b ] ) ;
133+ return this . compileAndRun ( program , [ a , b ] ) as Array1D ;
134134 }
135135
136136 protected concat2DInternal ( a : Array2D , b : Array2D , axis : number ) : Array2D {
137137 const program = new ConcatProgram ( a . shape , b . shape , axis ) ;
138- return this . compileAndRun ( program , [ a , b ] ) ;
138+ return this . compileAndRun ( program , [ a , b ] ) as Array2D ;
139139 }
140140
141141 protected concat3DInternal ( x1 : Array3D , x2 : Array3D , axis : number ) : Array3D {
142142 const program = new ConcatProgram ( x1 . shape , x2 . shape , axis ) ;
143- return this . compileAndRun ( program , [ x1 , x2 ] ) ;
143+ return this . compileAndRun ( program , [ x1 , x2 ] ) as Array3D ;
144144 }
145145
146146 protected concat4DInternal ( x1 : Array4D , x2 : Array4D , axis : number ) : Array4D {
147147 const program = new ConcatProgram ( x1 . shape , x2 . shape , axis ) ;
148- return this . compileAndRun ( program , [ x1 , x2 ] ) ;
148+ return this . compileAndRun ( program , [ x1 , x2 ] ) as Array4D ;
149149 }
150150
151151 protected scaledArrayAddInternal < T extends NDArray > (
152- c1 : Scalar , a : T , c2 : Scalar , b : T ) {
152+ c1 : Scalar , a : T , c2 : Scalar , b : T ) : T {
153153 const program = new AddScaledMatProgram ( a . shape , b . shape ) ;
154- return this . compileAndRun < NDArray , T > ( program , [ a , b , c1 , c2 ] ) ;
154+ return this . compileAndRun < NDArray , T > ( program , [ a , b , c1 , c2 ] ) as T ;
155155 }
156156
157157 protected negInternal < T extends NDArray > ( a : T ) : T {
158158 const program = new UnaryOpProgram ( a . shape , unary_op . NEG ) ;
159- return this . compileAndRun < T , T > ( program , [ a ] ) ;
159+ return this . compileAndRun ( program , [ a ] ) as T ;
160160 }
161161
162162 private makeOutputArray < T extends NDArray > ( shape : number [ ] ) : T {
@@ -186,12 +186,12 @@ export class NDArrayMathGPU extends NDArrayMath {
186186 bOrientation : MatrixOrientation ) : Array2D {
187187 const program =
188188 new MatMulProgram ( a . shape , b . shape , aOrientation , bOrientation ) ;
189- return this . compileAndRun < Array2D , Array2D > ( program , [ a , b ] ) ;
189+ return this . compileAndRun < Array2D , Array2D > ( program , [ a , b ] ) as Array2D ;
190190 }
191191
192192 protected multiplyInternal < T extends NDArray > ( a : T , b : T ) : T {
193193 const program = new BinaryOpProgram ( binaryop_gpu . MUL , a . shape , b . shape ) ;
194- return this . compileAndRun < T , T > ( program , [ a , b ] ) ;
194+ return this . compileAndRun ( program , [ a , b ] ) as T ;
195195 }
196196
197197 protected batchNormalization3DInternal (
@@ -219,7 +219,7 @@ export class NDArrayMathGPU extends NDArrayMath {
219219 const program = new BatchNormProgram (
220220 x . shape , mean . shape , variance . shape , offsetShape , scaleShape ,
221221 varianceEpsilon ) ;
222- return this . compileAndRun ( program , inputs ) ;
222+ return this . compileAndRun ( program , inputs ) as Array3D ;
223223 }
224224
225225 protected switchDimInternal < T extends NDArray > ( a : T , newDim : number [ ] ) : T {
@@ -283,103 +283,103 @@ export class NDArrayMathGPU extends NDArrayMath {
283283
284284 protected expInternal < T extends NDArray > ( a : T ) : T {
285285 const program = new UnaryOpProgram ( a . shape , unary_op . EXP ) ;
286- return this . compileAndRun ( program , [ a ] ) ;
286+ return this . compileAndRun ( program , [ a ] ) as T ;
287287 }
288288
289289 protected logInternal < T extends NDArray > ( a : T ) : T {
290290 const program = new UnaryOpProgram ( a . shape , unary_op . LOG ) ;
291- return this . compileAndRun ( program , [ a ] ) ;
291+ return this . compileAndRun ( program , [ a ] ) as T ;
292292 }
293293
294294 protected sqrtInternal < T extends NDArray > ( a : T ) : T {
295295 const program = new UnaryOpProgram ( a . shape , unary_op . SQRT ) ;
296- return this . compileAndRun ( program , [ a ] ) ;
296+ return this . compileAndRun ( program , [ a ] ) as T ;
297297 }
298298
299299 protected reluInternal < T extends NDArray > ( a : T ) : T {
300300 const program = new UnaryOpProgram ( a . shape , unary_op . RELU ) ;
301- return this . compileAndRun ( program , [ a ] ) ;
301+ return this . compileAndRun ( program , [ a ] ) as T ;
302302 }
303303
304304 protected absInternal < T extends NDArray > ( a : T ) : T {
305305 const program = new UnaryOpProgram ( a . shape , unary_op . ABS ) ;
306- return this . compileAndRun ( program , [ a ] ) ;
306+ return this . compileAndRun ( program , [ a ] ) as T ;
307307 }
308308
309309 protected sigmoidInternal < T extends NDArray > ( a : T ) : T {
310310 const program = new UnaryOpProgram ( a . shape , unary_op . SIGMOID ) ;
311- return this . compileAndRun < T , T > ( program , [ a ] ) ;
311+ return this . compileAndRun ( program , [ a ] ) as T ;
312312 }
313313
314314 protected sinInternal < T extends NDArray > ( a : T ) : T {
315315 const program = new UnaryOpProgram ( a . shape , unary_op . SIN ) ;
316- return this . compileAndRun ( program , [ a ] ) ;
316+ return this . compileAndRun ( program , [ a ] ) as T ;
317317 }
318318
319319 protected cosInternal < T extends NDArray > ( a : T ) : T {
320320 const program = new UnaryOpProgram ( a . shape , unary_op . COS ) ;
321- return this . compileAndRun ( program , [ a ] ) ;
321+ return this . compileAndRun ( program , [ a ] ) as T ;
322322 }
323323
324324 protected tanInternal < T extends NDArray > ( a : T ) : T {
325325 const program = new UnaryOpProgram ( a . shape , unary_op . TAN ) ;
326- return this . compileAndRun ( program , [ a ] ) ;
326+ return this . compileAndRun ( program , [ a ] ) as T ;
327327 }
328328
329329 protected asinInternal < T extends NDArray > ( a : T ) : T {
330330 const program = new UnaryOpProgram ( a . shape , unary_op . ASIN ) ;
331- return this . compileAndRun ( program , [ a ] ) ;
331+ return this . compileAndRun ( program , [ a ] ) as T ;
332332 }
333333
334334 protected acosInternal < T extends NDArray > ( a : T ) : T {
335335 const program = new UnaryOpProgram ( a . shape , unary_op . ACOS ) ;
336- return this . compileAndRun ( program , [ a ] ) ;
336+ return this . compileAndRun ( program , [ a ] ) as T ;
337337 }
338338
339339 protected atanInternal < T extends NDArray > ( a : T ) : T {
340340 const program = new UnaryOpProgram ( a . shape , unary_op . ATAN ) ;
341- return this . compileAndRun ( program , [ a ] ) ;
341+ return this . compileAndRun ( program , [ a ] ) as T ;
342342 }
343343
344344 protected sinhInternal < T extends NDArray > ( a : T ) : T {
345345 const program = new UnaryOpProgram ( a . shape , unary_op . SINH ) ;
346- return this . compileAndRun ( program , [ a ] ) ;
346+ return this . compileAndRun ( program , [ a ] ) as T ;
347347 }
348348
349349 protected coshInternal < T extends NDArray > ( a : T ) : T {
350350 const program = new UnaryOpProgram ( a . shape , unary_op . COSH ) ;
351- return this . compileAndRun ( program , [ a ] ) ;
351+ return this . compileAndRun ( program , [ a ] ) as T ;
352352 }
353353
354354 protected tanhInternal < T extends NDArray > ( a : T ) : T {
355355 const program = new UnaryOpProgram ( a . shape , unary_op . TANH ) ;
356- return this . compileAndRun ( program , [ a ] ) ;
356+ return this . compileAndRun ( program , [ a ] ) as T ;
357357 }
358358
359359
360360 protected stepInternal < T extends NDArray > ( a : T ) : T {
361361 const program = new UnaryOpProgram ( a . shape , unary_op . STEP ) ;
362- return this . compileAndRun ( program , [ a ] ) ;
362+ return this . compileAndRun ( program , [ a ] ) as T ;
363363 }
364364
365365 protected conv2dInternal (
366366 x : Array3D , filter : Array4D , bias : Array1D | null ,
367367 convInfo : ConvInfo ) : Array3D {
368368 const program = new Conv2DProgram ( convInfo , bias != null ) ;
369369 const inputs = bias != null ? [ x , filter , bias ] : [ x , filter ] ;
370- return this . compileAndRun ( program , inputs ) ;
370+ return this . compileAndRun ( program , inputs ) as Array3D ;
371371 }
372372
373373 protected conv2dDerInputInternal (
374374 dy : Array3D , filter : Array4D , convInfo : ConvInfo ) : Array3D {
375375 const program = new Conv2DDerInputProgram ( convInfo ) ;
376- return this . compileAndRun ( program , [ dy , filter ] ) ;
376+ return this . compileAndRun ( program , [ dy , filter ] ) as Array3D ;
377377 }
378378
379379 protected conv2dDerFilterInternal (
380380 x : Array3D , dY : Array3D , convInfo : ConvInfo ) : Array4D {
381381 const program = new Conv2DDerWeightsProgram ( convInfo ) ;
382- return this . compileAndRun ( program , [ x , dY ] ) ;
382+ return this . compileAndRun ( program , [ x , dY ] ) as Array4D ;
383383 }
384384
385385 protected conv2dDerBiasInternal ( dY : Array3D ) : Array1D {
@@ -389,17 +389,17 @@ export class NDArrayMathGPU extends NDArrayMath {
389389
390390 protected maxPoolInternal ( x : Array3D , convInfo : ConvInfo ) : Array3D {
391391 const program = new Pool2DProgram ( convInfo , 'max' , false ) ;
392- return this . compileAndRun ( program , [ x ] ) ;
392+ return this . compileAndRun ( program , [ x ] ) as Array3D ;
393393 }
394394
395395 protected minPoolInternal ( x : Array3D , convInfo : ConvInfo ) : Array3D {
396396 const program = new Pool2DProgram ( convInfo , 'min' , false ) ;
397- return this . compileAndRun ( program , [ x ] ) ;
397+ return this . compileAndRun ( program , [ x ] ) as Array3D ;
398398 }
399399
400400 protected avgPoolInternal ( x : Array3D , convInfo : ConvInfo ) : Array3D {
401401 const program = new Pool2DProgram ( convInfo , 'avg' , false ) ;
402- return this . compileAndRun ( program , [ x ] ) ;
402+ return this . compileAndRun ( program , [ x ] ) as Array3D ;
403403 }
404404
405405 protected maxPoolBackpropInternal (
@@ -408,7 +408,7 @@ export class NDArrayMathGPU extends NDArrayMath {
408408 const maxPoolPositionsProgram =
409409 new Pool2DProgram ( convInfo , 'max' , getPositions ) ;
410410 const maxPoolPositions : Array3D =
411- this . compileAndRun ( maxPoolPositionsProgram , [ x ] ) ;
411+ this . compileAndRun ( maxPoolPositionsProgram , [ x ] ) as Array3D ;
412412
413413 const maxPoolBackPropProgram = new MaxPool2DBackpropProgram ( convInfo ) ;
414414
@@ -423,7 +423,7 @@ export class NDArrayMathGPU extends NDArrayMath {
423423 alignCorners : boolean ) : Array3D {
424424 const program =
425425 new ResizeBilinear3DProgram ( x . shape , newShape2D , alignCorners ) ;
426- return this . compileAndRun ( program , [ x ] ) ;
426+ return this . compileAndRun ( program , [ x ] ) as Array3D ;
427427 }
428428
429429 protected multinomialInternal (
0 commit comments