Skip to content

Commit 83d3f76

Browse files
omalleyt12tensorflow-copybara
authored andcommitted
Fix test where shapes are not matched to keras.Input shapes.
PiperOrigin-RevId: 297461518
1 parent eec7323 commit 83d3f76

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

neural_structured_learning/keras/layers/pairwise_distance_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,13 @@ def _distance_fn(x, y):
8989
return np.mean(np.square(x - y))
9090

9191
# Common input.
92-
sources = np.array([1., 1., 1., 1.])
93-
targets = np.array([[4., 3., 2., 1.]])
92+
sources = np.array([[1., 1., 1., 1.]])
93+
targets = np.array([[[4., 3., 2., 1.]]])
9494
unweighted_distance = _distance_fn(sources, targets)
9595

9696
def _make_symbolic_weights_model():
9797
"""Makes a model where the weights are provided as input."""
98+
# Shape doesn't include batch dimension.
9899
inputs = {
99100
'sources': tf.keras.Input(4),
100101
'targets': tf.keras.Input((1, 4)),
@@ -104,7 +105,7 @@ def _make_symbolic_weights_model():
104105
outputs = pairwise_distance_fn(**inputs)
105106
return tf.keras.Model(inputs=inputs, outputs=outputs)
106107

107-
weights = np.array([[2.]])
108+
weights = np.array([[[2.]]])
108109
expected_distance = unweighted_distance * weights
109110
model = _make_symbolic_weights_model()
110111
self.assertNear(
@@ -117,6 +118,7 @@ def _make_symbolic_weights_model():
117118

118119
def _make_fixed_weights_model(weights):
119120
"""Makes a model where the weights are a static constant."""
121+
# Shape doesn't include batch dimension.
120122
inputs = {
121123
'sources': tf.keras.Input(4),
122124
'targets': tf.keras.Input((1, 4)),

0 commit comments

Comments
 (0)