@@ -17,6 +17,7 @@ import {NDArray} from '../ndarray';
1717
1818import { GPGPUContext } from './gpgpu_context' ;
1919import * as shader_compiler from './shader_compiler' ;
20+ import { ShapeInfo } from './shader_compiler' ;
2021import * as util from '../../util' ;
2122
2223export interface GPGPUProgram {
@@ -26,47 +27,47 @@ export interface GPGPUProgram {
2627 userCode : string ;
2728}
2829
29- export interface GPGPUBinary < T extends NDArray , K extends NDArray > {
30+ export interface GPGPUBinary {
3031 webGLProgram : WebGLProgram ;
3132 program : GPGPUProgram ;
3233 gpgpu : GPGPUContext ;
3334 source : string ;
34- inputs : T [ ] ;
35- output : K ;
35+ inShapeInfos : ShapeInfo [ ] ;
36+ outShapeInfo : ShapeInfo ;
3637}
3738
3839export function compileProgram < T extends NDArray , K extends NDArray > (
3940 gpgpu : GPGPUContext , program : GPGPUProgram , inputs : T [ ] ,
40- output : K ) : GPGPUBinary < T , K > {
41+ output : K ) : GPGPUBinary {
4142 const userCode = program . userCode ;
42- const programInputs = program . variableNames . map ( ( x , i ) => {
43- const fullShape = {
44- shape : inputs [ i ] . shape ,
43+ const inputInfos = program . variableNames . map ( ( x , i ) => {
44+ const shapeInfo = {
45+ logicalShape : inputs [ i ] . shape ,
4546 texShape : inputs [ i ] . getTextureShapeRC ( )
4647 } ;
47- return { name : x , fullShape } ;
48+ return { name : x , shapeInfo } ;
4849 } ) ;
49-
50- const outFullShape = {
51- shape : output . shape ,
50+ const inShapeInfos = inputInfos . map ( x => x . shapeInfo ) ;
51+ const outShapeInfo = {
52+ logicalShape : output . shape ,
5253 texShape : output . getTextureShapeRC ( )
5354 } ;
54- const source = shader_compiler . makeShader ( programInputs , outFullShape ,
55+ const source = shader_compiler . makeShader ( inputInfos , outShapeInfo ,
5556 userCode ) ;
5657 return {
5758 program,
5859 source,
5960 webGLProgram : gpgpu . createProgram ( source ) ,
6061 gpgpu,
61- inputs ,
62- output
62+ inShapeInfos ,
63+ outShapeInfo
6364 } ;
6465}
6566
66- function validateBinaryAndProgram ( aArrays : NDArray [ ] , bArrays : NDArray [ ] ) {
67- aArrays . forEach ( ( a , i ) => {
68- const shapeA = a . shape ;
69- const texShapeA = a . getTextureShapeRC ( ) ;
67+ function validateBinaryAndProgram ( shapeInfos : ShapeInfo [ ] , bArrays : NDArray [ ] ) {
68+ shapeInfos . forEach ( ( s , i ) => {
69+ const shapeA = s . logicalShape ;
70+ const texShapeA = s . texShape ;
7071 const shapeB = bArrays [ i ] . shape ;
7172 const texShapeB = bArrays [ i ] . getTextureShapeRC ( ) ;
7273
@@ -82,17 +83,10 @@ function validateBinaryAndProgram(aArrays: NDArray[], bArrays: NDArray[]) {
8283}
8384
8485export function runProgram < T extends NDArray , K extends NDArray > (
85- binary : GPGPUBinary < T , K > , inputs ?: T [ ] , output ?: K ) : void {
86- if ( inputs == null ) {
87- inputs = binary . inputs ;
88- } else {
89- validateBinaryAndProgram ( binary . inputs , inputs ) ;
90- }
91- if ( output == null ) {
92- output = binary . output ;
93- } else {
94- validateBinaryAndProgram ( [ binary . output ] , [ output ] ) ;
95- }
86+ binary : GPGPUBinary , inputs : T [ ] , output : K ) : void {
87+ validateBinaryAndProgram ( binary . inShapeInfos , inputs ) ;
88+ validateBinaryAndProgram ( [ binary . outShapeInfo ] , [ output ] ) ;
89+
9690 const outTex = output . getTexture ( ) ;
9791 const outTexShape = output . getTextureShapeRC ( ) ;
9892 const gpgpu = binary . gpgpu ;
0 commit comments