Skip to content

Commit 005ab91

Browse files
Lingjun LiuLingjun Liu
authored andcommitted
attention visualisation
1 parent 6ecca88 commit 005ab91

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/models/test_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def tearDownClass(cls):
7474
pass
7575

7676
def test_basic_simpleSeq2Seq(self):
77-
77+
7878
model_ = Transformer(TINY_PARAMS)
7979

8080
# print(", ".join(x for x in [t.name for t in model_.trainable_weights]))
@@ -93,7 +93,7 @@ def test_basic_simpleSeq2Seq(self):
9393
with tf.GradientTape() as tape:
9494

9595
targets = Y
96-
logits = model_(inputs = X, targets = Y)
96+
logits, weights_encoder, weights_decoder = model_(inputs = X, targets = Y)
9797
logits = metrics.MetricLayer(self.vocab_size)([logits, targets])
9898
logits, loss = metrics.LossLayer(self.vocab_size, 0.1)([logits, targets])
9999

@@ -108,7 +108,7 @@ def test_basic_simpleSeq2Seq(self):
108108
model_.eval()
109109
test_sample = trainX[0:2, :]
110110
model_.eval()
111-
prediction = model_(inputs = test_sample)
111+
[prediction, weights_decoder], weights_encoder = model_(inputs = test_sample)
112112

113113
print("Prediction: >>>>> ", prediction["outputs"], "\n Target: >>>>> ", trainY[0:2, :], "\n\n")
114114

0 commit comments

Comments
 (0)