@@ -77,12 +77,55 @@ def forward(self, inputs, targets=None):
       training: boolean, whether in training mode or not.

     Returns:
-      If targets is defined, then return logits for each word in the target
-      sequence. float tensor with shape [batch_size, target_length, vocab_size]
-      If target is none, then generate output sequence one token at a time.
-        returns a dictionary {
-          outputs: [batch_size, decoded length]
-          scores: [batch_size, float]}
+      If targets is defined:
+        Logits for each word in the target sequence:
+          float tensor with shape [batch_size, target_length, vocab_size]
+        Self-attention weights for the encoder:
+          a dictionary of float tensors {
+            "layer_0": [batch_size, number_of_heads, source_length, source_length],
+            "layer_1": [batch_size, number_of_heads, source_length, source_length],
+            ...
+          }
+        Attention weights for the decoder:
+          a dictionary of dictionaries of float tensors {
+            "self": {
+              "layer_0": [batch_size, number_of_heads, target_length, target_length],
+              "layer_1": [batch_size, number_of_heads, target_length, target_length],
+              ...
+            },
+            "enc_dec": {
+              "layer_0": [batch_size, number_of_heads, source_length, target_length],
+              "layer_1": [batch_size, number_of_heads, source_length, target_length],
+              ...
+            }
+          }
+
+      If targets is None:
+        Auto-regressive beam-search decoding generates the output sequence one
+        token at a time and returns:
+          a dictionary {
+            outputs: [batch_size, decoded length],
+            scores: [batch_size, float]
+          }
+        Attention weights for the decoder:
+          a dictionary of dictionaries of float tensors {
+            "self": {
+              "layer_0": [batch_size, number_of_heads, target_length, target_length],
+              "layer_1": [batch_size, number_of_heads, target_length, target_length],
+              ...
+            },
+            "enc_dec": {
+              "layer_0": [batch_size, number_of_heads, source_length, target_length],
+              "layer_1": [batch_size, number_of_heads, source_length, target_length],
+              ...
+            }
+          }
+        Self-attention weights for the encoder:
+          a dictionary of float tensors {
+            "layer_0": [batch_size, number_of_heads, source_length, source_length],
+            "layer_1": [batch_size, number_of_heads, source_length, source_length],
+            ...
+          }
+
     """
     # # Variance scaling is used here because it seems to work in many problems.
     # # Other reasonable initializers may also work just as well.
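For downstream consumers of these attention dictionaries, here is a small self-contained sketch of how the per-layer weights might be stacked for inspection; the tensors and dimension sizes below are dummies built to the documented shapes, not output from the model:

```python
import tensorflow as tf

# Dummy attention weights with the documented shapes (sizes are illustrative).
batch_size, number_of_heads, source_length, target_length = 2, 8, 5, 7
enc_weights = {
    "layer_%d" % i: tf.ones([batch_size, number_of_heads, source_length, source_length])
    for i in range(6)
}
dec_weights = {
    "self": {
        "layer_%d" % i: tf.ones([batch_size, number_of_heads, target_length, target_length])
        for i in range(6)
    },
    "enc_dec": {
        "layer_%d" % i: tf.ones([batch_size, number_of_heads, source_length, target_length])
        for i in range(6)
    },
}

# Stack the per-layer encoder maps into one tensor:
# [num_layers, batch_size, number_of_heads, source_length, source_length]
enc_stack = tf.stack([enc_weights["layer_%d" % i] for i in range(len(enc_weights))])

# Average over heads to get one [source_length, source_length] map per layer
# for the first example in the batch.
per_layer_maps = tf.reduce_mean(enc_stack[:, 0], axis=1)
print(per_layer_maps.shape)  # (6, 5, 5)
```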
@@ -118,6 +161,7 @@ def encode(self, inputs, attention_bias):

     Returns:
       float tensor with shape [batch_size, input_length, hidden_size]
+
     """

     # Prepare inputs to the layer stack by adding positional encodings and
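The truncated comment above concerns adding positional encodings before the encoder stack. For reference, a minimal sketch of the standard sinusoidal encoding from the Transformer paper; the helper name and signature here are illustrative, not necessarily the ones used in this file:

```python
import numpy as np
import tensorflow as tf

def sinusoidal_position_encoding(length, hidden_size,
                                 min_timescale=1.0, max_timescale=1.0e4):
    # Standard "Attention Is All You Need" encoding: sin/cos at geometrically
    # spaced timescales, concatenated into a [length, hidden_size] table.
    num_timescales = hidden_size // 2
    position = np.arange(length, dtype=np.float32)
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(
        -log_increment * np.arange(num_timescales, dtype=np.float32))
    scaled_time = position[:, None] * inv_timescales[None, :]
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    return tf.constant(signal, dtype=tf.float32)

# Added to the (scaled) embeddings before the first encoder layer.
pos_encoding = sinusoidal_position_encoding(length=50, hidden_size=512)
print(pos_encoding.shape)  # (50, 512)
```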
@@ -223,7 +267,12 @@ def symbols_to_logits_fn(ids, i, cache):
     return symbols_to_logits_fn, weights

   def predict(self, encoder_outputs, encoder_decoder_attention_bias):
-    """Return predicted sequence."""
+    """Return the predicted sequence and the decoder attention weights."""
     batch_size = tf.shape(encoder_outputs)[0]
     input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params.extra_decode_length
@@ -263,7 +312,15 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
     top_decoded_ids = decoded_ids[:, 0, 1:]
     top_scores = scores[:, 0]

-    return {"outputs": top_decoded_ids, "scores": top_scores}, weights
+    # Merge the per-step attention weights collected while decoding into a
+    # single tensor per layer.
+    for i, weight in enumerate(weights):
+      if i == 0:
+        w = weight
+      else:
+        for k in range(len(w['self'])):
+          w['self']['layer_%d' % k] = tf.concat(
+              [w['self']['layer_%d' % k], weight['self']['layer_%d' % k]], 3)
+          w['enc_dec']['layer_%d' % k] = tf.concat(
+              [w['enc_dec']['layer_%d' % k], weight['enc_dec']['layer_%d' % k]], 2)
+    return {"outputs": top_decoded_ids, "scores": top_scores}, w


 class LayerNormalization(tl.layers.Layer):
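The weight-merging loop added to `predict` concatenates the attention maps collected at each beam-search step into one tensor per layer. A self-contained sketch of that accumulation pattern follows; the per-step shape and the concatenation axis are assumptions chosen for illustration, not the decoder's actual shapes:

```python
import tensorflow as tf

# Assume each decode step yields one attention map with a singleton time axis:
# [batch_size, number_of_heads, 1, source_length] (illustrative shape).
batch_size, number_of_heads, source_length, num_steps = 2, 4, 7, 3
per_step = [tf.ones([batch_size, number_of_heads, 1, source_length])
            for _ in range(num_steps)]

merged = per_step[0]
for step_weights in per_step[1:]:
    # Grow the accumulated tensor by one slice per generated token.
    merged = tf.concat([merged, step_weights], axis=2)

print(merged.shape)  # (2, 4, 3, 7)
```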