 import os
 import unittest
 
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-
 import numpy as np
 import tensorflow as tf
 import tensorlayer as tl
 from tensorlayer.models.transformer import Transformer
 from tests.utils import CustomTestCase
 from tensorlayer.models.transformer.utils import metrics
-from tensorlayer.cost import cross_entropy_seq
 from tensorlayer.optimizers import lazyAdam as optimizer
+from tensorlayer.models.transformer.utils import attention_visualisation
 import time
 
 
@@ -51,7 +49,7 @@ class Model_SEQ2SEQ_Test(CustomTestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.batch_size = 16
+        cls.batch_size = 50
 
         cls.embedding_size = 32
         cls.dec_seq_length = 5
@@ -66,7 +64,7 @@ def setUpClass(cls):
 
         assert cls.src_len == cls.tgt_len
 
-        cls.num_epochs = 1000
+        cls.num_epochs = 20
         cls.n_step = cls.src_len // cls.batch_size
 
     @classmethod
@@ -99,8 +97,8 @@ def test_basic_simpleSeq2Seq(self):
 
                     grad = tape.gradient(loss, model_.all_weights)
                     optimizer.apply_gradients(zip(grad, model_.all_weights))
-
 
+
                 total_loss += loss
                 n_iter += 1
             print(time.time() - t)
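For context, the step this hunk adjusts is the usual TF2 eager training step: a forward pass through the Transformer inside a tf.GradientTape, a loss, then tape.gradient followed by optimizer.apply_gradients over model_.all_weights. A minimal sketch of one such step is below; the train_step helper name and the sparse cross-entropy loss are illustrative assumptions (the test builds its own loss elsewhere in the file), while the model call signature, all_weights, and apply_gradients usage are taken from the diff.

# Hypothetical helper for illustration; the loss below is an assumption,
# the test defines its own loss elsewhere in this file.
import tensorflow as tf

def train_step(model_, optimizer, X, Y):
    with tf.GradientTape() as tape:
        # forward pass returns logits plus encoder/decoder attention weights
        logits, weights_encoder, weights_decoder = model_(inputs=X, targets=Y)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=Y, logits=logits)
        )
    # backpropagate through all trainable weights and apply the update
    grad = tape.gradient(loss, model_.all_weights)
    optimizer.apply_gradients(zip(grad, model_.all_weights))
    return loss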
@@ -115,5 +113,20 @@ def test_basic_simpleSeq2Seq(self):
             print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, self.num_epochs, total_loss / n_iter))
 
 
+        # visualise the self-attention weights at the encoder
+        trainX, trainY = shuffle(self.trainX, self.trainY)
+        X = [trainX[0]]
+        Y = [trainY[0]]
+        logits, weights_encoder, weights_decoder = model_(inputs=X, targets=Y)
+        attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], X[0].numpy(), X[0].numpy())
+
+        # visualise the encoder-decoder attention weights at the decoder
+        trainX, trainY = shuffle(self.trainX, self.trainY)
+        X = [trainX[0]]
+        Y = [trainY[0]]
+        logits, weights_encoder, weights_decoder = model_(inputs=X, targets=Y)
+        attention_visualisation.plot_attention_weights(weights_decoder["enc_dec"]["layer_0"], X[0].numpy(), Y[0])
+
+
 if __name__ == '__main__':
     unittest.main()
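The only information this diff gives about attention_visualisation.plot_attention_weights is its call signature: a per-layer attention weight tensor plus the source/target token ids used as axis labels. A minimal sketch of what such a helper could look like follows, assuming matplotlib heatmaps with one panel per attention head; this is illustrative only, not the actual tensorlayer implementation.

# Illustrative sketch only; the real tensorlayer utility may differ.
import matplotlib.pyplot as plt
import numpy as np

def plot_attention_weights(attention, key_tokens, query_tokens):
    # attention: [batch, num_heads, query_len, key_len]; plot the first example
    attention = np.asarray(attention)[0]
    num_heads = attention.shape[0]
    fig = plt.figure(figsize=(4 * num_heads, 4))
    for head in range(num_heads):
        ax = fig.add_subplot(1, num_heads, head + 1)
        ax.matshow(attention[head], cmap='viridis')  # one heatmap per head
        ax.set_xticks(range(len(key_tokens)))
        ax.set_yticks(range(len(query_tokens)))
        ax.set_xticklabels([str(t) for t in key_tokens], rotation=90)
        ax.set_yticklabels([str(t) for t in query_tokens])
        ax.set_title('head {}'.format(head))
    plt.tight_layout()
    plt.show()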