@@ -182,28 +182,28 @@ def scaled_dot_product_attention(queries,
     Refer to `Attention Is All You Need
     <https://arxiv.org/pdf/1706.03762.pdf>`_.
 
     Note that batch data containing sequences with different lengths is not
     supported by this because of the (batch) matrix multiplication.
 
     Args:
-        queries (Variable): The input variable which is a Tensor or
-                            LoDTensor.
-        keys (Variable): The input variable which is a Tensor or LoDTensor.
-        values (Variable): The input variable which is a Tensor or
-                           LoDTensor.
-        num_heads (int): Head number to compute the dot product attention.
-        dropout_rate (float): The dropout rate for attention weight.
+        queries (Variable): The input variable which should be a 3-D Tensor.
+        keys (Variable): The input variable which should be a 3-D Tensor.
+        values (Variable): The input variable which should be a 3-D Tensor.
+        num_heads (int): Head number to compute the scaled dot product
+                         attention. Default value is 1.
+        dropout_rate (float): The dropout rate to drop the attention weight.
+                              Default value is 0.
 
     Returns:
-        Variable: The context Tensor computed by multi-head scaled dot product
-                  attention.
+        Variable: A 3-D Tensor computed by multi-head scaled dot product
+                  attention.
 
     Examples:
         .. code-block:: python
 
-            # Suppose q, k, v are tensor variables with the following
-            # shape: q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
-            out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
+            # Suppose q, k, v are Tensors with the following shape:
+            # q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
+
+            contexts = fluid.nets.dot_product_attention(q, k, v)
-            out.shape  # [3, 5, 10]
-            attn_scores.shape  # [3, 5, 6]
     """
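The docstring example above only states the input and output shapes. As a sanity check, the sketch below reproduces the single-head computation it describes in plain NumPy (this is an illustration, not the Fluid API; the softmax helper and the random inputs are stand-ins):

.. code-block:: python

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    # Same shapes as in the docstring example.
    q = np.random.rand(3, 5, 9)
    k = np.random.rand(3, 6, 9)
    v = np.random.rand(3, 6, 10)

    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V, batched.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(k.shape[-1])  # [3, 5, 6]
    weights = softmax(scores)                                  # [3, 5, 6]
    context = weights @ v                                      # [3, 5, 10]
    print(context.shape)  # (3, 5, 10)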
@@ -227,19 +227,30 @@ def scaled_dot_product_attention(queries,
                          "by the number of attention heads (%d)." %
                          (values.shape[-1], num_heads))
 
+    def __compute_qkv(queries, keys, values, num_heads):
+        if num_heads == 1:
+            return queries, keys, values
+
+        q = layers.fc(input=queries, size=queries.shape[-1], num_flatten_dims=2)
+        k = layers.fc(input=keys, size=keys.shape[-1], num_flatten_dims=2)
+        v = layers.fc(input=values, size=values.shape[-1], num_flatten_dims=2)
+        return q, k, v
+
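The new ``__compute_qkv`` helper applies a learned linear projection to each of queries, keys, and values before the heads are split. With ``num_flatten_dims=2`` the ``fc`` layer treats the first two dimensions as the batch and multiplies only the last dimension, and since the output size equals the input size, shapes are preserved. A rough NumPy equivalent (random weights stand in for the learned parameters, bias omitted):

.. code-block:: python

    import numpy as np

    queries = np.random.rand(3, 5, 9)   # [batch, seq_len, hidden]
    w_q = np.random.rand(9, 9)          # stand-in for the learned fc weight

    # Roughly what fc(..., size=hidden, num_flatten_dims=2) computes: project
    # the last dimension, leaving [batch, seq_len] untouched.
    q = queries @ w_q
    print(q.shape)  # (3, 5, 9)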
     def __split_heads(x, num_heads):
         """
         Reshape the last dimension of input tensor x so that it becomes two
         dimensions.
 
         Args:
             x(Tensor): a 3-D input Tensor.
             num_heads(int): The number of heads.
 
         Returns:
-            a Tensor with shape [..., n, m/n]
+            Tensor: a Tensor with shape [..., n, m/num_heads], where m is size
+                    of the last dimension of x.
         """
-        if num_heads == 1: return x
+        if num_heads == 1:
+            return x
 
         hidden_size = x.shape[-1]
         # reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
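``__split_heads`` turns the projected 3-D tensor into a 4-D one so that each head attends independently: the hidden dimension is split into ``num_heads`` slices and the head axis is moved ahead of the sequence axis (the ``perm=[0, 2, 1, 3]`` transpose shown in the next hunk). A NumPy sketch with illustrative shapes:

.. code-block:: python

    import numpy as np

    x = np.random.rand(3, 5, 8)   # [batch, seq_len, hidden]
    num_heads = 2

    # [batch, seq_len, hidden] -> [batch, seq_len, num_heads, hidden/num_heads]
    reshaped = x.reshape(3, 5, num_heads, 8 // num_heads)
    # -> [batch, num_heads, seq_len, hidden/num_heads]
    split = reshaped.transpose(0, 2, 1, 3)
    print(split.shape)  # (3, 2, 5, 4)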
@@ -254,6 +265,19 @@ def scaled_dot_product_attention(queries,
         return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
 
     def __combine_heads(x):
+        """
+        Reshape the last two dimensions of input tensor x so that it becomes
+        one dimension.
+
+        Args:
+            x(Tensor): a 4-D input Tensor with shape
+                       [bs, num_heads, max_sequence_length, hidden_dim].
+
+        Returns:
+            Tensor: a Tensor with shape
+                    [bs, max_sequence_length, num_heads * hidden_dim].
+        """
+
         if len(x.shape) == 3: return
+        if len(x.shape) != 4:
+            raise ValueError("Input(x) should be a 4-D Tensor.")
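``__combine_heads`` is the inverse of ``__split_heads``: it moves the head axis back next to the per-head feature axis and merges the two into a single hidden dimension. A NumPy sketch with illustrative shapes:

.. code-block:: python

    import numpy as np

    x = np.random.rand(3, 2, 5, 4)   # [batch, num_heads, seq_len, dim_per_head]

    # [batch, num_heads, seq_len, dim_per_head] -> [batch, seq_len, num_heads, dim_per_head]
    trans_x = x.transpose(0, 2, 1, 3)
    # -> [batch, seq_len, num_heads * dim_per_head]
    combined = trans_x.reshape(3, 5, 2 * 4)
    print(combined.shape)  # (3, 5, 8)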
@@ -266,9 +290,11 @@ def scaled_dot_product_attention(queries,
                     trans_x.shape[2] * trans_x.shape[3]
                 ]))
 
-    q = __split_heads(queries, num_heads)
-    k = __split_heads(keys, num_heads)
-    v = __split_heads(values, num_heads)
+    q, k, v = __compute_qkv(queries, keys, values, num_heads)
+
+    q = __split_heads(q, num_heads)
+    k = __split_heads(k, num_heads)
+    v = __split_heads(v, num_heads)
 
     key_dim_per_head = keys.shape[-1] // num_heads
     scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
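After the heads are split, the queries are scaled by ``key_dim_per_head ** -0.5``, i.e. 1/sqrt(d_k), before the dot products, which keeps the softmax logits in a reasonable range as the per-head key dimension grows. The sketch below ties the pieces together in NumPy with illustrative shapes; the learned fc projections and the dropout on the attention weights are omitted:

.. code-block:: python

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    def split_heads(x, num_heads):
        b, s, h = x.shape
        return x.reshape(b, s, num_heads, h // num_heads).transpose(0, 2, 1, 3)

    def combine_heads(x):
        b, n, s, d = x.shape
        return x.transpose(0, 2, 1, 3).reshape(b, s, n * d)

    batch, q_len, k_len, hidden, num_heads = 3, 5, 6, 8, 2
    q = np.random.rand(batch, q_len, hidden)
    k = np.random.rand(batch, k_len, hidden)
    v = np.random.rand(batch, k_len, hidden)

    q, k, v = (split_heads(t, num_heads) for t in (q, k, v))

    key_dim_per_head = hidden // num_heads
    scaled_q = q * key_dim_per_head ** -0.5                  # mirrors layers.scale
    weights = softmax(scaled_q @ k.transpose(0, 1, 3, 2))    # [batch, heads, q_len, k_len]
    contexts = combine_heads(weights @ v)                    # [batch, q_len, hidden]
    print(contexts.shape)  # (3, 5, 8)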