@@ -40,6 +40,12 @@ export class ConvOps {
4040 * - For more info, see this guide:
4141 * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
4242 * https://www.tensorflow.org/api_guides/python/nn#Convolution)
43+ * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
44+ * the data is stored in the order of [batch, in_width, in_channels]. Only
45+ * "NWC" is currently supported.
46+ * @param dilation The dilation rate in which we sample input values in
47+ * atrous convolution. Defaults to `1`. If it is greater than 1, then
48+ * stride must be `1`.
4349 * @param dimRoundingMode The rounding mode used when computing output
4450 * dimensions if pad is a number. If none is provided, it will not round
4551 * and error if the output is of fractional size.
@@ -48,6 +54,7 @@ export class ConvOps {
4854 @operation
4955 static conv1d < T extends Tensor2D | Tensor3D > (
5056 input : T , filter : Tensor3D , stride : number , pad : 'valid' | 'same' | number ,
57+ dataFormat : 'NWC' | 'NCW' = 'NWC' , dilation = 1 ,
5158 dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
5259 let input3D = input as Tensor3D ;
5360 let reshapedTo3D = false ;
@@ -74,15 +81,27 @@ export class ConvOps {
7481 input3D . shape [ 2 ] === filter . shape [ 1 ] ,
7582 `Error in conv1d: depth of input (${ input3D . shape [ 2 ] } ) must match ` +
7683 `input depth for filter ${ filter . shape [ 1 ] } .` ) ;
84+ util . assert (
85+ eitherStridesOrDilationsAreOne ( stride , dilation ) ,
86+ 'Error in conv1D: Either stride or dilation must be 1.' +
87+ `Got stride ${ stride } and dilation '${ dilation } '` ) ;
88+ util . assert (
89+ dataFormat === 'NWC' ,
90+ `Error in conv1d: got dataFormat of ${
91+ dataFormat } but only NWC is currently supported.`) ;
7792
7893 const filter4D =
7994 filter . as4D ( 1 , filter . shape [ 0 ] , filter . shape [ 1 ] , filter . shape [ 2 ] ) ;
8095 const input4D =
8196 input3D . as4D ( input3D . shape [ 0 ] , 1 , input3D . shape [ 1 ] , input3D . shape [ 2 ] ) ;
8297 const strides : [ number , number ] = [ 1 , stride ] ;
98+ const dilations : [ number , number ] = [ 1 , dilation ] ;
99+
100+ const conv2dDataFormat = 'NHWC' ;
83101
84- const res =
85- ConvOps . conv2d ( input4D , filter4D , strides , pad , dimRoundingMode ) ;
102+ const res = ConvOps . conv2d (
103+ input4D , filter4D , strides , pad , conv2dDataFormat , dilations ,
104+ dimRoundingMode ) ;
86105
87106 if ( reshapedTo3D ) {
88107 return res . as2D ( res . shape [ 2 ] , res . shape [ 3 ] ) as T ;
@@ -108,6 +127,15 @@ export class ConvOps {
108127 * - For more info, see this guide:
109128 * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
110129 * https://www.tensorflow.org/api_guides/python/nn#Convolution)
130+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
131+ * "NHWC". Specify the data format of the input and output data. With the
132+ * default format "NHWC", the data is stored in the order of: [batch,
133+ * height, width, channels]. Only "NHWC" is currently supported.
134+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
135+ * in which we sample input values across the height and width dimensions
136+ * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
137+ * number, then `dilationHeight == dilationWidth`. If it is greater than
138+ * 1, then all values of `strides` must be 1.
111139 * @param dimRoundingMode The rounding mode used when computing output
112140 * dimensions if pad is a number. If none is provided, it will not round
113141 * and error if the output is of fractional size.
@@ -116,7 +144,9 @@ export class ConvOps {
116144 @operation
117145 static conv2d < T extends Tensor3D | Tensor4D > (
118146 x : T , filter : Tensor4D , strides : [ number , number ] | number ,
119- pad : 'valid' | 'same' | number , dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
147+ pad : 'valid' | 'same' | number , dataFormat : 'NHWC' | 'NCHW' = 'NHWC' ,
148+ dilations : [ number , number ] | number = [ 1 , 1 ] ,
149+ dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
120150 let x4D = x as Tensor4D ;
121151 let reshapedTo4D = false ;
122152
@@ -142,13 +172,24 @@ export class ConvOps {
142172 x4D . shape [ 3 ] === filter . shape [ 2 ] ,
143173 `Error in conv2d: depth of input (${ x4D . shape [ 3 ] } ) must match ` +
144174 `input depth for filter ${ filter . shape [ 2 ] } .` ) ;
145-
146- const dilations = 1 ;
175+ util . assert (
176+ eitherStridesOrDilationsAreOne ( strides , dilations ) ,
177+ 'Error in conv2D: Either strides or dilations must be 1.' +
178+ `Got strides ${ strides } and dilations '${ dilations } '` ) ;
179+ util . assert (
180+ dataFormat === 'NHWC' ,
181+ `Error in conv2d: got dataFormat of ${
182+ dataFormat } but only NHWC is currently supported.`) ;
147183
148184 const convInfo = conv_util . computeConv2DInfo (
149185 x4D . shape , filter . shape , strides , dilations , pad , dimRoundingMode ) ;
150186
151187 const grad = ( dy : Tensor4D ) => {
188+ util . assert (
189+ tupleValuesAreOne ( dilations ) ,
190+ 'Error in gradient of conv2D: dilation rates greater than 1 are not' +
191+ `yet supported in gradients. Got dilations '${ dilations } '` ) ;
192+
152193 return {
153194 x : ( ) => ConvOps . conv2dDerInput ( x4D . shape , dy , filter , strides , pad ) ,
154195 filter : ( ) =>
@@ -375,9 +416,13 @@ export class ConvOps {
375416 * https://www.tensorflow.org/api_guides/python/nn#Convolution)
376417 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
377418 * in which we sample input values across the height and width dimensions
378- * in atrous convolution. Defaults to `[1, 1]`. If `dilations ` is a single
419+ * in atrous convolution. Defaults to `[1, 1]`. If `rate ` is a single
379420 * number, then `dilationHeight == dilationWidth`. If it is greater than
380421 * 1, then all values of `strides` must be 1.
422+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
423+ * "NHWC". Specify the data format of the input and output data. With the
424+ * default format "NHWC", the data is stored in the order of: [batch,
425+ * height, width, channels]. Only "NHWC" is currently supported.
381426 * @param dimRoundingMode The rounding mode used when computing output
382427 * dimensions if pad is a number. If none is provided, it will not round
383428 * and error if the output is of fractional size.
@@ -386,7 +431,8 @@ export class ConvOps {
386431 @operation
387432 static depthwiseConv2d < T extends Tensor3D | Tensor4D > (
388433 input : T , filter : Tensor4D , strides : [ number , number ] | number ,
389- pad : 'valid' | 'same' | number , dilations : [ number , number ] | number = [ 1 , 1 ] ,
434+ pad : 'valid' | 'same' | number , dataFormat : 'NHWC' | 'NCHW' = 'NHWC' ,
435+ dilations : [ number , number ] | number = [ 1 , 1 ] ,
390436 dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
391437 let input4D = input as Tensor4D ;
392438 let reshapedTo4D = false ;
@@ -410,11 +456,11 @@ export class ConvOps {
410456 if ( dilations == null ) {
411457 dilations = [ 1 , 1 ] ;
412458 }
413- const [ dilationHeight , dilationWidth ] = parseTupleParam ( dilations ) ;
414459 util . assert (
415- dilationHeight === 1 && dilationWidth === 1 ,
416- 'Error in depthwiseConv2D: dilation rates greater than 1 are not yet ' +
417- `supported. Got dilations '${ dilations } '` ) ;
460+ eitherStridesOrDilationsAreOne ( strides , dilations ) ,
461+ 'Error in depthwiseConv2d: Either strides or dilations must be 1.' +
462+ `Got strides ${ strides } and dilations '${ dilations } '` ) ;
463+
418464 if ( dimRoundingMode != null ) {
419465 util . assert (
420466 util . isInt ( pad as number ) ,
@@ -438,3 +484,14 @@ export class ConvOps {
438484function parseTupleParam ( param : number | [ number , number ] ) : [ number , number ] {
439485 return typeof param === 'number' ? [ param , param ] : param ;
440486}
487+
488+ function tupleValuesAreOne ( param : number | [ number , number ] ) : boolean {
489+ const [ dimA , dimB ] = parseTupleParam ( param ) ;
490+ return dimA === 1 && dimB === 1 ;
491+ }
492+
493+ function eitherStridesOrDilationsAreOne (
494+ strides : number | [ number , number ] ,
495+ dilations : number | [ number , number ] ) : boolean {
496+ return tupleValuesAreOne ( strides ) || tupleValuesAreOne ( dilations ) ;
497+ }
0 commit comments