@@ -89,12 +89,13 @@ def _distance_fn(x, y):
8989 return np .mean (np .square (x - y ))
9090
9191 # Common input.
92- sources = np .array ([1. , 1. , 1. , 1. ])
93- targets = np .array ([[4. , 3. , 2. , 1. ]])
92+ sources = np .array ([[ 1. , 1. , 1. , 1. ] ])
93+ targets = np .array ([[[ 4. , 3. , 2. , 1. ] ]])
9494 unweighted_distance = _distance_fn (sources , targets )
9595
9696 def _make_symbolic_weights_model ():
9797 """Makes a model where the weights are provided as input."""
98+ # Shape doesn't include batch dimension.
9899 inputs = {
99100 'sources' : tf .keras .Input (4 ),
100101 'targets' : tf .keras .Input ((1 , 4 )),
@@ -104,7 +105,7 @@ def _make_symbolic_weights_model():
104105 outputs = pairwise_distance_fn (** inputs )
105106 return tf .keras .Model (inputs = inputs , outputs = outputs )
106107
107- weights = np .array ([[2. ]])
108+ weights = np .array ([[[ 2. ] ]])
108109 expected_distance = unweighted_distance * weights
109110 model = _make_symbolic_weights_model ()
110111 self .assertNear (
@@ -117,6 +118,7 @@ def _make_symbolic_weights_model():
117118
118119 def _make_fixed_weights_model (weights ):
119120 """Makes a model where the weights are a static constant."""
121+ # Shape doesn't include batch dimension.
120122 inputs = {
121123 'sources' : tf .keras .Input (4 ),
122124 'targets' : tf .keras .Input ((1 , 4 )),
0 commit comments