Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit aaae240

Browse files
arthurjdamNikhil Thorat
authored andcommitted
Add and synchronize typings (#161)
* fixed some typings * mainly, use "as" syntax instead of <>
1 parent 5cb5951 commit aaae240

File tree

4 files changed

+47
-40
lines changed

4 files changed

+47
-40
lines changed

src/math/math_cpu.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ export class NDArrayMathCPU extends NDArrayMath {
234234
}
235235

236236
protected scaledArrayAddInternal<T extends NDArray>(
237-
c1: Scalar, a: T, c2: Scalar, b: T) {
237+
c1: Scalar, a: T, c2: Scalar, b: T): T {
238238
const newShape = util.assertAndGetBroadcastedShape(a.shape, b.shape);
239239
const newValues = new Float32Array(util.sizeFromShape(newShape));
240240

src/math/math_gpu.ts

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,23 @@ export class NDArrayMathGPU extends NDArrayMath {
9898
]): Array2D {
9999
const program = new SliceProgram(size);
100100
const customSetup = program.getCustomSetupFunc(begin);
101-
return this.compileAndRun(program, [input], null, customSetup);
101+
return this.compileAndRun(program, [input], null, customSetup) as Array2D;
102102
}
103103

104104
protected slice3DInternal(
105105
input: Array3D, begin: [number, number, number],
106106
size: [number, number, number]): Array3D {
107107
const program = new SliceProgram(size);
108108
const customSetup = program.getCustomSetupFunc(begin);
109-
return this.compileAndRun(program, [input], null, customSetup);
109+
return this.compileAndRun(program, [input], null, customSetup) as Array3D;
110110
}
111111

112112
protected slice4DInternal(
113113
input: Array4D, begin: [number, number, number, number],
114114
size: [number, number, number, number]): Array4D {
115115
const program = new SliceProgram(size);
116116
const customSetup = program.getCustomSetupFunc(begin);
117-
return this.compileAndRun(program, [input], null, customSetup);
117+
return this.compileAndRun(program, [input], null, customSetup) as Array4D;
118118
}
119119

120120
protected copy2DInternal(
@@ -130,33 +130,33 @@ export class NDArrayMathGPU extends NDArrayMath {
130130

131131
protected concat1DInternal(a: Array1D, b: Array1D): Array1D {
132132
const program = new ConcatProgram(a.shape, b.shape, 0);
133-
return this.compileAndRun(program, [a, b]);
133+
return this.compileAndRun(program, [a, b]) as Array1D;
134134
}
135135

136136
protected concat2DInternal(a: Array2D, b: Array2D, axis: number): Array2D {
137137
const program = new ConcatProgram(a.shape, b.shape, axis);
138-
return this.compileAndRun(program, [a, b]);
138+
return this.compileAndRun(program, [a, b]) as Array2D;
139139
}
140140

141141
protected concat3DInternal(x1: Array3D, x2: Array3D, axis: number): Array3D {
142142
const program = new ConcatProgram(x1.shape, x2.shape, axis);
143-
return this.compileAndRun(program, [x1, x2]);
143+
return this.compileAndRun(program, [x1, x2]) as Array3D;
144144
}
145145

146146
protected concat4DInternal(x1: Array4D, x2: Array4D, axis: number): Array4D {
147147
const program = new ConcatProgram(x1.shape, x2.shape, axis);
148-
return this.compileAndRun(program, [x1, x2]);
148+
return this.compileAndRun(program, [x1, x2]) as Array4D;
149149
}
150150

151151
protected scaledArrayAddInternal<T extends NDArray>(
152-
c1: Scalar, a: T, c2: Scalar, b: T) {
152+
c1: Scalar, a: T, c2: Scalar, b: T): T {
153153
const program = new AddScaledMatProgram(a.shape, b.shape);
154-
return this.compileAndRun<NDArray, T>(program, [a, b, c1, c2]);
154+
return this.compileAndRun<NDArray, T>(program, [a, b, c1, c2]) as T;
155155
}
156156

157157
protected negInternal<T extends NDArray>(a: T): T {
158158
const program = new UnaryOpProgram(a.shape, unary_op.NEG);
159-
return this.compileAndRun<T, T>(program, [a]);
159+
return this.compileAndRun(program, [a]) as T;
160160
}
161161

162162
private makeOutputArray<T extends NDArray>(shape: number[]): T {
@@ -186,12 +186,12 @@ export class NDArrayMathGPU extends NDArrayMath {
186186
bOrientation: MatrixOrientation): Array2D {
187187
const program =
188188
new MatMulProgram(a.shape, b.shape, aOrientation, bOrientation);
189-
return this.compileAndRun<Array2D, Array2D>(program, [a, b]);
189+
return this.compileAndRun<Array2D, Array2D>(program, [a, b]) as Array2D;
190190
}
191191

192192
protected multiplyInternal<T extends NDArray>(a: T, b: T): T {
193193
const program = new BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape);
194-
return this.compileAndRun<T, T>(program, [a, b]);
194+
return this.compileAndRun(program, [a, b]) as T;
195195
}
196196

197197
protected batchNormalization3DInternal(
@@ -219,7 +219,7 @@ export class NDArrayMathGPU extends NDArrayMath {
219219
const program = new BatchNormProgram(
220220
x.shape, mean.shape, variance.shape, offsetShape, scaleShape,
221221
varianceEpsilon);
222-
return this.compileAndRun(program, inputs);
222+
return this.compileAndRun(program, inputs) as Array3D;
223223
}
224224

225225
protected switchDimInternal<T extends NDArray>(a: T, newDim: number[]): T {
@@ -283,103 +283,103 @@ export class NDArrayMathGPU extends NDArrayMath {
283283

284284
protected expInternal<T extends NDArray>(a: T): T {
285285
const program = new UnaryOpProgram(a.shape, unary_op.EXP);
286-
return this.compileAndRun(program, [a]);
286+
return this.compileAndRun(program, [a]) as T;
287287
}
288288

289289
protected logInternal<T extends NDArray>(a: T): T {
290290
const program = new UnaryOpProgram(a.shape, unary_op.LOG);
291-
return this.compileAndRun(program, [a]);
291+
return this.compileAndRun(program, [a]) as T;
292292
}
293293

294294
protected sqrtInternal<T extends NDArray>(a: T): T {
295295
const program = new UnaryOpProgram(a.shape, unary_op.SQRT);
296-
return this.compileAndRun(program, [a]);
296+
return this.compileAndRun(program, [a]) as T;
297297
}
298298

299299
protected reluInternal<T extends NDArray>(a: T): T {
300300
const program = new UnaryOpProgram(a.shape, unary_op.RELU);
301-
return this.compileAndRun(program, [a]);
301+
return this.compileAndRun(program, [a]) as T;
302302
}
303303

304304
protected absInternal<T extends NDArray>(a: T): T {
305305
const program = new UnaryOpProgram(a.shape, unary_op.ABS);
306-
return this.compileAndRun(program, [a]);
306+
return this.compileAndRun(program, [a]) as T;
307307
}
308308

309309
protected sigmoidInternal<T extends NDArray>(a: T): T {
310310
const program = new UnaryOpProgram(a.shape, unary_op.SIGMOID);
311-
return this.compileAndRun<T, T>(program, [a]);
311+
return this.compileAndRun(program, [a]) as T;
312312
}
313313

314314
protected sinInternal<T extends NDArray>(a: T): T {
315315
const program = new UnaryOpProgram(a.shape, unary_op.SIN);
316-
return this.compileAndRun(program, [a]);
316+
return this.compileAndRun(program, [a]) as T;
317317
}
318318

319319
protected cosInternal<T extends NDArray>(a: T): T {
320320
const program = new UnaryOpProgram(a.shape, unary_op.COS);
321-
return this.compileAndRun(program, [a]);
321+
return this.compileAndRun(program, [a]) as T;
322322
}
323323

324324
protected tanInternal<T extends NDArray>(a: T): T {
325325
const program = new UnaryOpProgram(a.shape, unary_op.TAN);
326-
return this.compileAndRun(program, [a]);
326+
return this.compileAndRun(program, [a]) as T;
327327
}
328328

329329
protected asinInternal<T extends NDArray>(a: T): T {
330330
const program = new UnaryOpProgram(a.shape, unary_op.ASIN);
331-
return this.compileAndRun(program, [a]);
331+
return this.compileAndRun(program, [a]) as T;
332332
}
333333

334334
protected acosInternal<T extends NDArray>(a: T): T {
335335
const program = new UnaryOpProgram(a.shape, unary_op.ACOS);
336-
return this.compileAndRun(program, [a]);
336+
return this.compileAndRun(program, [a]) as T;
337337
}
338338

339339
protected atanInternal<T extends NDArray>(a: T): T {
340340
const program = new UnaryOpProgram(a.shape, unary_op.ATAN);
341-
return this.compileAndRun(program, [a]);
341+
return this.compileAndRun(program, [a]) as T;
342342
}
343343

344344
protected sinhInternal<T extends NDArray>(a: T): T {
345345
const program = new UnaryOpProgram(a.shape, unary_op.SINH);
346-
return this.compileAndRun(program, [a]);
346+
return this.compileAndRun(program, [a]) as T;
347347
}
348348

349349
protected coshInternal<T extends NDArray>(a: T): T {
350350
const program = new UnaryOpProgram(a.shape, unary_op.COSH);
351-
return this.compileAndRun(program, [a]);
351+
return this.compileAndRun(program, [a]) as T;
352352
}
353353

354354
protected tanhInternal<T extends NDArray>(a: T): T {
355355
const program = new UnaryOpProgram(a.shape, unary_op.TANH);
356-
return this.compileAndRun(program, [a]);
356+
return this.compileAndRun(program, [a]) as T;
357357
}
358358

359359

360360
protected stepInternal<T extends NDArray>(a: T): T {
361361
const program = new UnaryOpProgram(a.shape, unary_op.STEP);
362-
return this.compileAndRun(program, [a]);
362+
return this.compileAndRun(program, [a]) as T;
363363
}
364364

365365
protected conv2dInternal(
366366
x: Array3D, filter: Array4D, bias: Array1D|null,
367367
convInfo: ConvInfo): Array3D {
368368
const program = new Conv2DProgram(convInfo, bias != null);
369369
const inputs = bias != null ? [x, filter, bias] : [x, filter];
370-
return this.compileAndRun(program, inputs);
370+
return this.compileAndRun(program, inputs) as Array3D;
371371
}
372372

373373
protected conv2dDerInputInternal(
374374
dy: Array3D, filter: Array4D, convInfo: ConvInfo): Array3D {
375375
const program = new Conv2DDerInputProgram(convInfo);
376-
return this.compileAndRun(program, [dy, filter]);
376+
return this.compileAndRun(program, [dy, filter]) as Array3D;
377377
}
378378

379379
protected conv2dDerFilterInternal(
380380
x: Array3D, dY: Array3D, convInfo: ConvInfo): Array4D {
381381
const program = new Conv2DDerWeightsProgram(convInfo);
382-
return this.compileAndRun(program, [x, dY]);
382+
return this.compileAndRun(program, [x, dY]) as Array4D;
383383
}
384384

385385
protected conv2dDerBiasInternal(dY: Array3D): Array1D {
@@ -389,17 +389,17 @@ export class NDArrayMathGPU extends NDArrayMath {
389389

390390
protected maxPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D {
391391
const program = new Pool2DProgram(convInfo, 'max', false);
392-
return this.compileAndRun(program, [x]);
392+
return this.compileAndRun(program, [x]) as Array3D;
393393
}
394394

395395
protected minPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D {
396396
const program = new Pool2DProgram(convInfo, 'min', false);
397-
return this.compileAndRun(program, [x]);
397+
return this.compileAndRun(program, [x]) as Array3D;
398398
}
399399

400400
protected avgPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D {
401401
const program = new Pool2DProgram(convInfo, 'avg', false);
402-
return this.compileAndRun(program, [x]);
402+
return this.compileAndRun(program, [x]) as Array3D;
403403
}
404404

405405
protected maxPoolBackpropInternal(
@@ -408,7 +408,7 @@ export class NDArrayMathGPU extends NDArrayMath {
408408
const maxPoolPositionsProgram =
409409
new Pool2DProgram(convInfo, 'max', getPositions);
410410
const maxPoolPositions: Array3D =
411-
this.compileAndRun(maxPoolPositionsProgram, [x]);
411+
this.compileAndRun(maxPoolPositionsProgram, [x]) as Array3D;
412412

413413
const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
414414

@@ -423,7 +423,7 @@ export class NDArrayMathGPU extends NDArrayMath {
423423
alignCorners: boolean): Array3D {
424424
const program =
425425
new ResizeBilinear3DProgram(x.shape, newShape2D, alignCorners);
426-
return this.compileAndRun(program, [x]);
426+
return this.compileAndRun(program, [x]) as Array3D;
427427
}
428428

429429
protected multinomialInternal(

src/math/ndarray.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,5 +658,7 @@ export class Array4D extends NDArray {
658658
type ArrayData = Float32Array|number[]|number[][]|number[][][]|number[][][][];
659659

660660
function toTypedArray(a: ArrayData): Float32Array {
661-
return (a instanceof Float32Array) ? a : new Float32Array(util.flatten(a));
661+
return (a instanceof Float32Array) ?
662+
// tslint:disable-next-line:no-any
663+
a : new Float32Array(util.flatten(a as any[]));
662664
}

src/util.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ export function flatten(arr: any[], ret?: number[]): number[] {
106106
}
107107

108108
export type ArrayData =
109-
number | number[] | number[][] | number[][][] | number[][][][];
109+
Float32Array |
110+
number |
111+
number[] |
112+
number[][] |
113+
number[][][] |
114+
number[][][][];
110115

111116
export function inferShape(arr: ArrayData): number[] {
112117
const shape: number[] = [];

0 commit comments

Comments
 (0)