|
|
|
@ -1557,15 +1557,15 @@ def multi_head_attention(query,
|
|
|
|
|
for i in range(head_num):
|
|
|
|
|
with mixed_layer(size=key_proj_size) as sub_query_proj:
|
|
|
|
|
sub_query_proj += identity_projection(
|
|
|
|
|
query_proj, offset=key_proj_size * i)
|
|
|
|
|
query_proj, offset=key_proj_size * i, size=key_proj_size)
|
|
|
|
|
|
|
|
|
|
with mixed_layer(size=key_proj_size) as sub_key_proj:
|
|
|
|
|
sub_key_proj += identity_projection(
|
|
|
|
|
key_proj, offset=key_proj_size * i)
|
|
|
|
|
key_proj, offset=key_proj_size * i, size=key_proj_size)
|
|
|
|
|
|
|
|
|
|
with mixed_layer(size=value_proj_size) as sub_value_proj:
|
|
|
|
|
sub_value_proj += identity_projection(
|
|
|
|
|
value_proj, offset=value_proj_size * i)
|
|
|
|
|
value_proj, offset=value_proj_size * i, size=value_proj_size)
|
|
|
|
|
|
|
|
|
|
if attention_type == 'dot-product attention':
|
|
|
|
|
m = linear_comb_layer(
|
|
|
|
@ -1603,11 +1603,7 @@ def multi_head_attention(query,
|
|
|
|
|
|
|
|
|
|
head_list.append(head)
|
|
|
|
|
|
|
|
|
|
multi_head = concat_layer(head_list)
|
|
|
|
|
|
|
|
|
|
with mixed_layer(
|
|
|
|
|
size=value_proj_size * head_num, name='%s_proj' % name) as attended:
|
|
|
|
|
attended += full_matrix_projection(multi_head)
|
|
|
|
|
attended = concat_layer(head_list)
|
|
|
|
|
|
|
|
|
|
return attended
|
|
|
|
|
|
|
|
|
|