Commit 5b85af7

Lingjun Liu authored and committed
transformer updated
1 parent ac9d43b commit 5b85af7

File tree

13 files changed (+2386, -0 lines changed)

tensorlayer/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -9,3 +9,4 @@
 from .vgg import *
 from .seq2seq import Seq2seq
 from .seq2seq_with_attention import Seq2seqLuongAttention
+from .transformer import *
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
from .attention_layer import *
from .transformer import Transformer
from .beamsearchHelper import *
from .feedforward_layer import *
from .embedding_layer import *
from .utils import *
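
Given the wildcard re-exports above (and the new `from .transformer import *` in tensorlayer/models/__init__.py), the added classes should be importable straight from the models namespace. A minimal import sketch, assuming this hunk is the transformer package's __init__.py and that no __all__ restrictions apply elsewhere:

# Hypothetical import sketch; paths follow the package __init__ files in this commit.
from tensorlayer.models import Transformer
from tensorlayer.models.transformer import MultiHeadAttentionLayer, SelfAttentionLayer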
Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multi-headed attention and self-attention layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorlayer as tl


class MultiHeadAttentionLayer(tl.layers.Layer):
    """Multi-headed attention layer."""

    def __init__(self, num_heads, hidden_size, keep_prob):
        """Initialize the attention layer.

        Args:
            hidden_size: int, output dim of the hidden layer.
            num_heads: int, number of heads to repeat the same attention structure.
            keep_prob: float, keep rate for dropout inside attention during training.
        """
        if hidden_size % num_heads:
            raise ValueError(
                "Hidden size ({}) must be divisible by the number of heads ({}).".format(hidden_size, num_heads)
            )

        super(MultiHeadAttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_dropout = 1 - keep_prob

        self.build(None)
        self._built = True

    def get_config(self):
        return {
            "hidden_size": self.hidden_size,
            "num_heads": self.num_heads,
            "attention_dropout": self.attention_dropout,
        }

    def build(self, inputs_shape):
        # Transformations for linearly projecting the queries, keys, and values.
        self.q_transformation = self._get_weights(
            "q_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
        )
        self.v_transformation = self._get_weights(
            "v_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
        )
        self.k_transformation = self._get_weights(
            "k_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
        )
        self.out_transformation = self._get_weights(
            "out_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
        )

    def split_heads(self, x):
        """Split x into different heads, and transpose the resulting value.

        The tensor is transposed to ensure the inner dimensions hold the correct
        values during the matrix multiplication.

        Args:
            x: A tensor with shape [batch_size, length, hidden_size]

        Returns:
            A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
        """
        with tf.name_scope("split_heads"):
            batch_size = tf.shape(x)[0]
            length = tf.shape(x)[1]

            # Calculate depth of last dimension after it has been split.
            depth = (self.hidden_size // self.num_heads)

            # Split the last dimension
            x = tf.reshape(x, [batch_size, length, self.num_heads, depth])

            # Transpose the result
            return tf.transpose(x, [0, 2, 1, 3])

    def combine_heads(self, x):
        """Combine a tensor that has been split into heads.

        Args:
            x: A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]

        Returns:
            A tensor with shape [batch_size, length, hidden_size]
        """
        with tf.name_scope("combine_heads"):
            batch_size = tf.shape(x)[0]
            length = tf.shape(x)[2]
            x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
            return tf.reshape(x, [batch_size, length, self.hidden_size])

    def forward(self, inputs, mask, cache=None):
        """Apply the attention mechanism.

        Args:
            inputs: a list of two tensors [x, y] or three tensors [q, k, v], each with
                shape [batch_size, length, hidden_size]. With two tensors, y is used
                as both key and value.
            mask: attention bias that will be added to the result of the dot product.
            cache: (used during prediction) dictionary with tensors containing results
                of previous attentions. The dictionary must have the items:
                    {"k": tensor with shape [batch_size, i, key_channels],
                     "v": tensor with shape [batch_size, i, value_channels]}
                where i is the current decoded length.

        Returns:
            Attention layer output with shape [batch_size, length_q, hidden_size]
        """
        # Linearly project the query (q), key (k) and value (v) using different
        # learned projections. This is in preparation for splitting them into
        # multiple heads. Multi-head attention uses multiple queries, keys, and
        # values rather than regular attention (which uses a single q, k, v).
        if len(inputs) == 2:
            q = inputs[0]
            k = v = inputs[1]
        elif len(inputs) == 3:
            q, k, v = inputs

        q = tf.tensordot(q, self.q_transformation, axes=[[2], [0]])
        k = tf.tensordot(k, self.k_transformation, axes=[[2], [0]])
        v = tf.tensordot(v, self.v_transformation, axes=[[2], [0]])

        if cache is not None:
            # Combine cached keys and values with new keys and values.
            k = tf.concat([cache["k"], k], axis=1)
            v = tf.concat([cache["v"], v], axis=1)

            # Update cache
            cache["k"] = k
            cache["v"] = v

        # Split q, k, v into heads.
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)  # (Batch, num_head, length_v, dk)

        # Scale q to prevent the dot product between q and k from growing too large.
        depth = (self.hidden_size // self.num_heads)
        q *= depth**-0.5

        # Calculate dot product attention
        logits = tf.matmul(q, k, transpose_b=True)  # (Batch, num_head, length_q, length_k)
        logits += mask
        weights = tf.nn.softmax(logits, name="attention_weights")  # (Batch, num_head, length_q, length_k)
        if self.is_train:
            weights = tf.nn.dropout(weights, rate=self.attention_dropout)

        attention_output = tf.matmul(weights, v)

        # Recombine heads --> [batch_size, length, hidden_size]
        attention_output = self.combine_heads(attention_output)

        # Run the combined outputs through another linear projection layer.
        attention_output = tf.tensordot(attention_output, self.out_transformation, axes=[[2], [0]])
        return attention_output


class SelfAttentionLayer(MultiHeadAttentionLayer):
    """Multi-headed self-attention layer."""

    def forward(self, inputs, mask, cache=None):
        return super(SelfAttentionLayer, self).forward(inputs=[inputs, inputs], mask=mask, cache=cache)
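
For readers following the shape bookkeeping in forward above, here is a standalone, plain-TensorFlow sketch of the same pipeline (split heads, scaled dot-product attention, combine heads). The sizes are made up for illustration, and the snippet deliberately skips the learned projections, mask, and dropout so it stays independent of the TensorLayer layer machinery:

# Toy shapes; not part of the commit.
import tensorflow as tf

batch_size, length, hidden_size, num_heads = 2, 5, 64, 8
depth = hidden_size // num_heads  # 8

q = tf.random.uniform([batch_size, length, hidden_size])
k = v = tf.random.uniform([batch_size, length, hidden_size])

def split_heads(x):
    # [batch, length, hidden_size] -> [batch, num_heads, length, depth]
    x = tf.reshape(x, [batch_size, -1, num_heads, depth])
    return tf.transpose(x, [0, 2, 1, 3])

q, k, v = split_heads(q), split_heads(k), split_heads(v)
q *= depth ** -0.5                            # scale queries
logits = tf.matmul(q, k, transpose_b=True)    # [batch, num_heads, length_q, length_k]
weights = tf.nn.softmax(logits)               # no mask or dropout in this toy example
out = tf.matmul(weights, v)                   # [batch, num_heads, length_q, depth]
out = tf.transpose(out, [0, 2, 1, 3])         # combine heads
out = tf.reshape(out, [batch_size, length, hidden_size])
print(out.shape)                              # (2, 5, 64)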
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
from .beam_search import *
Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search in TF v2."""

import tensorflow as tf
import tensorlayer.models.transformer.beamsearchHelper.beam_search_v1 as v1

_StateKeys = v1._StateKeys  # pylint: disable=protected-access


class SequenceBeamSearchV2(v1.SequenceBeamSearch):
    """Implementation of the beam search loop in v2."""

    def search(self, initial_ids, initial_cache):
        """Beam search for sequences with the highest scores."""
        state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
        finished_state = tf.while_loop(
            self._continue_search, self._search_step, loop_vars=[state], shape_invariants=[state_shapes],
            parallel_iterations=1, back_prop=False
        )
        finished_state = finished_state[0]

        alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
        alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
        finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
        finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
        finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]

        # Account for the corner case where there are no finished sequences for a
        # particular batch item. In that case, return alive sequences for that
        # batch item.
        finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
        finished_scores = tf.where(tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
        return finished_seq, finished_scores


def sequence_beam_search(
    symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size, alpha, max_decode_length, eos_id
):
    """Search for the sequences of subtoken ids with the largest probability.

    Args:
        symbols_to_logits_fn: A function that takes in ids, index, and cache as
            arguments. The passed in arguments will have shape:
                ids -> [batch_size * beam_size, index]
                index -> [] (scalar)
                cache -> nested dictionary of tensors [batch_size * beam_size, ...]
            The function must return logits and the new cache:
                logits -> [batch * beam_size, vocab_size]
                new cache -> same shape/structure as the input cache
        initial_ids: Starting ids for each batch item.
            int32 tensor with shape [batch_size]
        initial_cache: dict containing starting decoder variables information
        vocab_size: int, size of the vocabulary
        beam_size: int, number of beams
        alpha: float defining the strength of length normalization
        max_decode_length: maximum length of the decoded sequence
        eos_id: int id of the eos token, used to determine when a sequence has finished

    Returns:
        Top decoded sequences [batch_size, beam_size, max_decode_length]
        sequence scores [batch_size, beam_size]
    """
    batch_size = tf.shape(initial_ids)[0]

    sbs = SequenceBeamSearchV2(
        symbols_to_logits_fn, vocab_size, batch_size, beam_size, alpha, max_decode_length, eos_id
    )
    return sbs.search(initial_ids, initial_cache)


def _expand_to_same_rank(tensor, target):
    """Expands a given tensor to target's rank to be broadcastable.

    Args:
        tensor: input tensor to tile. Shape: [b, d1, ..., da]
        target: target tensor. Shape: [b, d1, ..., da, ..., dn]

    Returns:
        Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with the same rank as target.

    Raises:
        ValueError, if the rank of the tensor or target shape is None.
    """
    if tensor.shape.rank is None:
        raise ValueError("Expect rank for tensor shape, but got None.")
    if target.shape.rank is None:
        raise ValueError("Expect rank for target shape, but got None.")

    with tf.name_scope("expand_rank"):
        diff_rank = target.shape.rank - tensor.shape.rank
        for _ in range(diff_rank):
            tensor = tf.expand_dims(tensor, -1)
        return tensor
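
To make the symbols_to_logits_fn contract documented above concrete, here is a hypothetical toy callable that follows it (uniform logits, cache returned unchanged), with a shape check for the first decoding step. The vocabulary, batch, and beam sizes are made up for illustration:

# Toy symbols_to_logits_fn following the contract in the sequence_beam_search docstring.
import tensorflow as tf

VOCAB_SIZE, BATCH_SIZE, BEAM_SIZE = 6, 2, 3

def symbols_to_logits_fn(ids, index, cache):
    # ids:   [batch_size * beam_size, index] int32 ids decoded so far
    # index: scalar decoding step
    # cache: nested dict of tensors, returned unchanged in this toy version
    logits = tf.zeros([tf.shape(ids)[0], VOCAB_SIZE])  # uniform score over the vocabulary
    return logits, cache

# Shape check for the first step, where ids holds only the initial id per beam.
ids = tf.zeros([BATCH_SIZE * BEAM_SIZE, 1], dtype=tf.int32)
logits, _ = symbols_to_logits_fn(ids, tf.constant(0), {})
print(logits.shape)  # (6, 6) == (batch_size * beam_size, vocab_size)

In use, a callable like this would be passed to sequence_beam_search together with int32 initial_ids of shape [batch_size], the decoder's initial_cache, and the vocab_size, beam_size, alpha, max_decode_length, and eos_id arguments described in the docstring.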
