@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import math
 
 from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
     IdentityActivation, TanhActivation, SequenceSoftmaxActivation
@@ -26,9 +26,9 @@ __all__ = [
     'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
     "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
     'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
-    'simple_attention', 'dot_product_attention', 'simple_gru2',
-    'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs',
-    'outputs'
+    'simple_attention', 'dot_product_attention', 'multi_head_attention',
+    'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
+    'inputs', 'outputs'
 ]
 
 ######################################################
@@ -1476,10 +1476,8 @@ def dot_product_attention(encoded_sequence,
         expand_as=encoded_sequence,
         name='%s_expand' % name)
 
-    m = linear_comb_layer(
-        weights=expanded,
-        vectors=encoded_sequence,
-        name='%s_dot-product' % name)
+    m = dot_prod_layer(
+        input1=expanded, input2=encoded_sequence, name='%s_dot-product' % name)
 
     attention_weight = fc_layer(
         input=m,
@@ -1498,6 +1496,134 @@
         input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
 
 
+@wrap_name_default()
+def multi_head_attention(query,
+                         key,
+                         value,
+                         key_proj_size,
+                         value_proj_size,
+                         head_num,
+                         attention_type,
+                         softmax_param_attr=None,
+                         name=None):
"""
|
|
|
|
|
Calculate and return a context vector with dot-product attention mechanism.
|
|
|
|
|
The dimension of the context vector equals to value_proj_size * head_num.
|
|
|
|
|
|
|
|
|
|
Please refer to **Attention Is All You Need** for more details. The link is
|
|
|
|
|
as follows:
|
|
|
|
|
https://arxiv.org/abs/1706.03762.
|
|
|
|
|
|
|
|
|
|
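+    Roughly, for the scaled dot-product type each head computes (with q, k_j and
+    v_j denoting this head's slices of the projected query, key and value,
+    d_{key} = key_proj_size, and w the scalar weight of the size-1 fc layer
+    inside):
+
+    .. math::
+
+        a_j = \mathrm{softmax}_j(w (q^T k_j) / \sqrt{d_{key}}), \qquad
+        head = \sum_j a_j v_j
+
+    and the per-head vectors are concatenated to form the context vector.
+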
+    The example usage is:
+
+    .. code-block:: python
+
+        context = multi_head_attention(query=decoder_state,
+                                       key=enc_seq,
+                                       value=enc_seq,
+                                       key_proj_size=64,
+                                       value_proj_size=64,
+                                       head_num=8,
+                                       attention_type='dot-product attention')
+
+    :param name: A prefix attached to the name of each layer defined inside
+                 multi_head_attention.
+    :type name: basestring
+    :param softmax_param_attr: The parameter attribute of the sequence softmax
+                               that is used to produce the attention weights.
+    :type softmax_param_attr: ParameterAttribute
+    :param query: The query, used to calculate attention weights over the values at the current step.
+    :type query: LayerOutput
+    :param key: The key, used to calculate the attention weight of the corresponding value.
+    :type key: LayerOutput
+    :param value: The value, i.e. the sequence to be attended to.
+    :type value: LayerOutput
+    :param key_proj_size: The dimension of the linear projection performed on the key and the query.
+    :type key_proj_size: int
+    :param value_proj_size: The dimension of the linear projection performed on the value.
+    :type value_proj_size: int
+    :param head_num: The number of attention heads.
+    :type head_num: int
+    :param attention_type: The type of attention mechanism used in each attention
+                           head. Currently, only scaled dot-product attention and
+                           additive attention are supported.
+    :type attention_type: basestring
+    :return: The context vector.
+    :rtype: LayerOutput
+    """
+    assert attention_type in ['dot-product attention', 'additive attention']
+
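+    # Project query and key into key_proj_size * head_num dimensions and value
+    # into value_proj_size * head_num dimensions; each head later slices out its
+    # own block of these shared projections.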
+    with mixed_layer(
+            size=key_proj_size * head_num,
+            name='%s_query_proj' % name) as query_proj:
+        query_proj += full_matrix_projection(query)
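+    # Replicate the projected query over every position of the key sequence so
+    # that it can be scored against each key.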
+    query_proj = expand_layer(input=query_proj, expand_as=key)
+
+    with mixed_layer(
+            size=key_proj_size * head_num,
+            name='%s_key_proj' % name) as key_proj:
+        key_proj += full_matrix_projection(key)
+
+    with mixed_layer(
+            size=value_proj_size * head_num,
+            name='%s_value_proj' % name) as value_proj:
+        value_proj += full_matrix_projection(value)
+
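+    # Build one context vector per head; identity_projection with an offset
+    # slices head i's block out of the shared projections.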
+    head_list = []
+    for i in range(head_num):
+        with mixed_layer(size=key_proj_size) as sub_query_proj:
+            sub_query_proj += identity_projection(
+                query_proj, offset=key_proj_size * i, size=key_proj_size)
+
+        with mixed_layer(size=key_proj_size) as sub_key_proj:
+            sub_key_proj += identity_projection(
+                key_proj, offset=key_proj_size * i, size=key_proj_size)
+
+        with mixed_layer(size=value_proj_size) as sub_value_proj:
+            sub_value_proj += identity_projection(
+                value_proj, offset=value_proj_size * i, size=value_proj_size)
+
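+        # Score each position: either a dot product between the sliced query and
+        # key scaled by 1/sqrt(key_proj_size), or an additive (tanh) combination
+        # of the two.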
+        if attention_type == 'dot-product attention':
+            m = dot_prod_layer(
+                input1=sub_query_proj,
+                input2=sub_key_proj,
+                name='%s_dot-product_%d' % (name, i))
+            m = slope_intercept_layer(
+                input=m,
+                slope=math.sqrt(1.0 / key_proj_size),
+                name='%s_dot-product_scaling_%d' % (name, i))
+        else:
+            with mixed_layer(
+                    size=key_proj_size,
+                    act=TanhActivation(),
+                    name='%s_combine_%d' % (name, i)) as m:
+                m += identity_projection(sub_query_proj)
+                m += identity_projection(sub_key_proj)
+
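+        # Turn the per-position scores into attention weights with a sequence
+        # softmax (a size-1 fc layer without bias).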
+        attention_weight = fc_layer(
+            input=m,
+            size=1,
+            act=SequenceSoftmaxActivation(),
+            param_attr=softmax_param_attr,
+            name="%s_softmax_%d" % (name, i),
+            bias_attr=False)
+
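+        # Weight each value by its attention weight and sum over the sequence to
+        # obtain this head's context vector.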
+        scaled = scaling_layer(
+            weight=attention_weight,
+            input=sub_value_proj,
+            name='%s_scaling_%d' % (name, i))
+        head = pooling_layer(
+            input=scaled,
+            pooling_type=SumPooling(),
+            name="%s_pooling_%d" % (name, i))
+
+        head_list.append(head)
+
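+    # Concatenate the per-head context vectors; the result has dimension
+    # value_proj_size * head_num.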
+    attended = concat_layer(head_list)
+
+    return attended
+
+
 def inputs(layers, *args):
     """
     Declare the inputs of network. The order of input should be as same as