Refine multi-head attention

release/0.11.0
ranqiu 8 years ago
parent 947c528508
commit 7461b35977

@@ -1557,15 +1557,15 @@ def multi_head_attention(query,
for i in range(head_num):
with mixed_layer(size=key_proj_size) as sub_query_proj:
sub_query_proj += identity_projection(
query_proj, offset=key_proj_size * i)
query_proj, offset=key_proj_size * i, size=key_proj_size)
with mixed_layer(size=key_proj_size) as sub_key_proj:
sub_key_proj += identity_projection(
key_proj, offset=key_proj_size * i)
key_proj, offset=key_proj_size * i, size=key_proj_size)
with mixed_layer(size=value_proj_size) as sub_value_proj:
sub_value_proj += identity_projection(
value_proj, offset=value_proj_size * i)
value_proj, offset=value_proj_size * i, size=value_proj_size)
if attention_type == 'dot-product attention':
m = linear_comb_layer(
@@ -1603,11 +1603,7 @@ def multi_head_attention(query,
head_list.append(head)
multi_head = concat_layer(head_list)
with mixed_layer(
size=value_proj_size * head_num, name='%s_proj' % name) as attended:
attended += full_matrix_projection(multi_head)
attended = concat_layer(head_list)
return attended

Loading…
Cancel
Save