@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import layers

__all__ = [
"simple_img_conv_pool",
"sequence_conv_pool",
"glu",
"dot_product_attention",
"scaled_dot_product_attention",
]
@@ -160,7 +159,11 @@ def glu(input, dim=-1):
return out


def dot_product_attention(querys, keys, values):
def scaled_dot_product_attention(queries,
keys,
values,
num_heads=1,
dropout_rate=0.):
"""
The dot-product attention.
@@ -174,39 +177,162 @@ def dot_product_attention(querys, keys, values):
.. math::

Attention(Q, K, V)= softmax(QK^\mathrm{T})V

Refer to `Attention Is All You Need
<https://arxiv.org/pdf/1706.03762.pdf>`_.

Note that batch data containing sequences of different lengths is not
supported by this operation because of the (batch) matrix multiplication.

Args:
query (Variable): The input variable which is a Tensor or LoDTensor.
key (Variable): The input variable which is a Tensor or LoDTensor.
value (Variable): The input variable which is a Tensor or LoDTensor.
queries (Variable): The input variable which should be a 3-D Tensor.
keys (Variable): The input variable which should be a 3-D Tensor.
values (Variable): The input variable which should be a 3-D Tensor.
num_heads (int): The number of heads used to compute the scaled dot
product attention. Default value is 1.
dropout_rate (float): The dropout rate applied to the attention weights.
Default value is 0.

Returns:
tuple: The Tensor variables representing the output and attention scores.
Variable: A 3-D Tensor computed by multi-head scaled dot product
attention.

Raises:
ValueError: If input queries, keys, values are not 3-D Tensors.

NOTE:
1. When num_heads > 1, three separate linear projections are learned to
map the input queries, keys and values to queries', keys' and values'.
queries', keys' and values' have the same shapes as queries, keys
and values.
2. When num_heads == 1, scaled_dot_product_attention has no learnable
parameters.

Examples:
.. code-block:: python

# Suppose q, k, v are tensor variables with the following shape:
# Suppose q, k, v are Tensors with the following shape:
# q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
out.shape # [3, 5, 10]
attn_scores.shape # [3, 5, 6]

contexts = fluid.nets.scaled_dot_product_attention(q, k, v)
contexts.shape # [3, 5, 10]
"""
assert keys.shape[-2] == values.shape[
-2], 'The shapes of keys and values mismatch.'
assert querys.shape[-1] == keys.shape[
-1], 'The shapes of querys and keys mismatch.'
product = layers.matmul(x=querys, y=keys, transpose_y=True)
attn_scores = layers.reshape(
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs queries, keys and values should all be 3-D tensors.")

if queries.shape[-1] != keys.shape[-1]:
raise ValueError(
"The hidden size of queries and keys should be the same.")
if keys.shape[-2] != values.shape[-2]:
raise ValueError(
"The max sequence length in query batch and in key batch "
"should be the same.")
if keys.shape[-1] % num_heads != 0:
raise ValueError("The hidden size of keys (%d) must be divisible "
"by the number of attention heads (%d)." %
(keys.shape[-1], num_heads))
if values.shape[-1] % num_heads != 0:
raise ValueError("The hidden size of values (%d) must be divisible "
"by the number of attention heads (%d)." %
(values.shape[-1], num_heads))

def __compute_qkv(queries, keys, values, num_heads):
"""
Add linear projections to queries, keys, and values.

Args:
queries(Tensor): a 3-D input Tensor.
keys(Tensor): a 3-D input Tensor.
values(Tensor): a 3-D input Tensor.
num_heads(int): The number of heads. Linearly project the inputs
ONLY when num_heads > 1.

Returns:
Tensor: linearly projected output Tensors: queries', keys' and
values'. They have the same shapes as queries, keys and
values.
"""
if num_heads == 1:
return queries, keys, values

q = layers.fc(input=queries, size=queries.shape[-1], num_flatten_dims=2)
k = layers.fc(input=keys, size=keys.shape[-1], num_flatten_dims=2)
v = layers.fc(input=values, size=values.shape[-1], num_flatten_dims=2)
return q, k, v
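# --- Editor's illustration, not part of this module or patch ---
# The fc projections above act position-wise and keep the
# [batch, seq_len, hidden] shape. A NumPy stand-in with a hypothetical
# weight matrix `w_q` makes that concrete.
import numpy as np

queries_np = np.ones((3, 5, 8))   # [batch, seq_len, hidden]
w_q = np.ones((8, 8))             # hidden x hidden projection weights
projected_q = queries_np @ w_q    # shape is still (3, 5, 8)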
def __split_heads(x, num_heads):
"""
Reshape the last dimension of input tensor x so that it becomes two
dimensions.

Args:
x(Tensor): a 3-D input Tensor.
num_heads(int): The number of heads.

Returns:
Tensor: a Tensor with shape [..., n, m/num_heads], where m is the size
of the last dimension of x.
"""
if num_heads == 1:
return x

hidden_size = x.shape[-1]
# reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
# into a 4-D output:
# [batch_size, max_sequence_length, num_heads, hidden_size_per_head].
reshaped = layers.reshape(
x=x,
shape=list(x.shape[:-1]) + [num_heads, hidden_size // num_heads])

# permute the dimensions into:
# [batch_size, num_heads, max_sequence_length, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
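# --- Editor's illustration, not part of this module or patch ---
# What __split_heads does to a concrete tensor, written with NumPy:
# [batch, seq_len, hidden] -> reshape -> transpose -> [batch, heads, seq_len, hidden // heads].
import numpy as np

x_np = np.arange(3 * 6 * 8).reshape(3, 6, 8)    # [batch=3, seq_len=6, hidden=8]
heads = 2
split = x_np.reshape(3, 6, heads, 8 // heads)   # [3, 6, 2, 4]
split = split.transpose(0, 2, 1, 3)             # [3, 2, 6, 4]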
def __combine_heads(x):
"""
Reshape the last two dimensions of input tensor x so that they become
one dimension.

Args:
x(Tensor): a 4-D input Tensor with shape
[bs, num_heads, max_sequence_length, hidden_dim].

Returns:
Tensor: a Tensor with shape
[bs, max_sequence_length, num_heads * hidden_dim].
"""

if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")

trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
return layers.reshape(
x=trans_x,
shape=list(map(int, [
trans_x.shape[0], trans_x.shape[1],
trans_x.shape[2] * trans_x.shape[3]
])))
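# --- Editor's illustration, not part of this module or patch ---
# __combine_heads is the inverse of __split_heads: it transposes
# [batch, heads, seq_len, head_dim] back and merges the last two dimensions
# into [batch, seq_len, heads * head_dim].
import numpy as np

y_np = np.zeros((3, 2, 6, 4))                           # [batch, heads, seq_len, head_dim]
combined = y_np.transpose(0, 2, 1, 3).reshape(3, 6, 8)  # [3, 6, 2 * 4]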
q, k, v = __compute_qkv(queries, keys, values, num_heads)

q = __split_heads(q, num_heads)
k = __split_heads(k, num_heads)
v = __split_heads(v, num_heads)

key_dim_per_head = keys.shape[-1] // num_heads
scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)

weights = layers.reshape(
x=layers.reshape(
x=product, shape=[-1, product.shape[-1]], act='softmax'),
x=product, shape=[-1, product.shape[-1]], act="softmax"),
shape=product.shape)
out = layers.matmul(attn_scores, values)
return out, attn_scores
if dropout_rate:
weights = layers.dropout(weights, dropout_prob=dropout_rate, is_test=False)
ctx_multiheads = layers.matmul(weights, v)
return __combine_heads(ctx_multiheads)
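# --- Editor's usage sketch, not part of this module or patch ---
# A hedged end-to-end example, assuming the function is exposed as
# fluid.nets.scaled_dot_product_attention and that the usual program and
# executor setup happens elsewhere; the layer names 'q', 'k' and 'v' are
# hypothetical. Hidden sizes (8 and 16) are chosen to be divisible by
# num_heads=2, as the checks above require.
import paddle.fluid as fluid

q = fluid.layers.data(name='q', shape=[5, 8], dtype='float32')    # [batch, 5, 8]
k = fluid.layers.data(name='k', shape=[6, 8], dtype='float32')    # [batch, 6, 8]
v = fluid.layers.data(name='v', shape=[6, 16], dtype='float32')   # [batch, 6, 16]
contexts = fluid.nets.scaled_dot_product_attention(
    q, k, v, num_heads=2, dropout_rate=0.1)
# contexts: [batch, 5, 16]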