Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit ae1dd4a

Browse files
AndreasMadsenNikhil Thorat
authored andcommitted
Fix tf.pow for zero exponent (#1823)
BUG This fixes tf.pow(0, 0), which in WebGL1 returns NaN but in WebGL2 and JavaScript return 1. Note that TensorFlow also defines pow(0,0) = 1.
1 parent bbc2ea4 commit ae1dd4a

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

src/backends/webgl/binaryop_gpu.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ export const MUL = 'return a * b;';
3333
export const DIV = `
3434
if (b == 0.0) {
3535
return NAN;
36-
}
36+
}
3737
if (a == b) {
3838
return 1.0;
3939
};
@@ -59,6 +59,9 @@ export const POW = `
5959
if(a < 0.0 && floor(b) < b){
6060
return NAN;
6161
}
62+
if (b == 0.0) {
63+
return 1.0;
64+
}
6265
return (round(mod(b, 2.0)) != 1) ?
6366
pow(abs(a), b) : sign(a) * pow(abs(a), b);
6467
`;

src/backends/webgl/binaryop_packed_gpu.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ export const DIV = `
5555
} else if(a.w == b.w) {
5656
result.w = 1.;
5757
}
58-
58+
5959
return result;
6060
`;
6161

@@ -88,6 +88,13 @@ export const POW = `
8888
vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);
8989
vec4 result = multiplier * pow(abs(a), b);
9090
91+
// Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS
92+
bvec4 isExpZero = equal(b, vec4(0.0));
93+
result.r = isExpZero.r ? 1.0 : result.r;
94+
result.g = isExpZero.g ? 1.0 : result.g;
95+
result.b = isExpZero.b ? 1.0 : result.b;
96+
result.a = isExpZero.a ? 1.0 : result.a;
97+
9198
vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));
9299
` +
93100
CHECK_NAN_SNIPPET + `

src/ops/arithmetic_test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,14 @@ describeWithFlags('pow', ALL_ENVS, () => {
738738
expectArraysClose(await result.data(), [NaN, 27, NaN, 0], 0.05);
739739
});
740740

741+
it('exponent of 0 returns 1', async () => {
742+
const a = tf.tensor1d([-2, -1, 0, 1, 2]);
743+
const b = tf.scalar(0);
744+
745+
const result = tf.pow(a, b);
746+
expectArraysClose(await result.data(), [1, 1, 1, 1, 1]);
747+
});
748+
741749
it('handles non int32 exponent param', async () => {
742750
const a = tf.tensor1d([2, 4]);
743751
const b = tf.tensor1d([.5, 1.2]);

0 commit comments

Comments
 (0)