Refine multi_head_attention

release/0.11.0
ranqiu 7 years ago
parent d29901b825
commit f22402933e

@@ -1586,9 +1586,9 @@ def multi_head_attention(query,
                 value_proj, offset=value_proj_size * i, size=value_proj_size)
             if attention_type == 'dot-product attention':
-                m = linear_comb_layer(
-                    weights=sub_query_proj,
-                    vectors=sub_key_proj,
+                m = dot_prod_layer(
+                    input1=sub_query_proj,
+                    input2=sub_key_proj,
                     name='%s_dot-product_%d' % (name, i))
             m = slope_intercept_layer(
                 input=m,
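For context, the change swaps linear_comb_layer for dot_prod_layer so that the per-head attention score is computed as a dot product between the projected query and key slices, with the following slope_intercept_layer applying the scaling. Below is a minimal NumPy sketch of the equivalent per-head scoring step; the names q_proj, k_proj, and key_proj_size are illustrative assumptions, not the PaddlePaddle layer API.

import numpy as np

def dot_product_attention_score(q_proj, k_proj, key_proj_size):
    """Illustrative sketch of the per-head dot-product attention score.

    q_proj: (seq_len, key_proj_size) projected query for one head.
    k_proj: (seq_len, key_proj_size) projected key for one head.
    Returns unnormalized, scaled scores, roughly analogous to
    dot_prod_layer followed by slope_intercept_layer scaling.
    """
    # Row-wise dot product between the two projections
    # (the role dot_prod_layer plays in the diff above).
    scores = np.sum(q_proj * k_proj, axis=-1)  # shape: (seq_len,)
    # Scale by 1/sqrt(d), as in standard scaled dot-product attention.
    return scores / np.sqrt(key_proj_size)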
