@@ -51,6 +51,7 @@ def generate_query_key_value_cache(self_attention,
                                    num_heads,
                                    query_length,
                                    embed_dim,
+                                   attn_mask_type,
                                    key_length=None,
                                    value_length=None,
                                    kdim=None,
@@ -58,8 +59,14 @@ def generate_query_key_value_cache(self_attention,
                                    cache=None):
     query = np.random.rand(batch_size, query_length,
                            embed_dim).astype("float32")
-    attn_mask = np.zeros((batch_size, num_heads, query_length, key_length))
-    attn_mask[0][0][0][0] = -1e9
+    attn_mask = np.ones(
+        (batch_size, num_heads, query_length, key_length), dtype=attn_mask_type)
+    if attn_mask_type == 'int64':
+        attn_mask = np.tril(attn_mask)
+    elif attn_mask_type == 'float64':
+        attn_mask = (np.tril(attn_mask) - 1.0) * 1e9
+    else:
+        raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
 
     head_dim = embed_dim // num_heads
     if self_attention:
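
For reference, a minimal standalone sketch (plain numpy, not part of the patch; the tiny
(1, 1, 3, 3) shape is chosen only for illustration) of the two mask flavors the updated
generate_query_key_value_cache produces:

import numpy as np

shape = (1, 1, 3, 3)  # (batch_size, num_heads, query_length, key_length)

# 'int64': 1 marks an attendable position, 0 a masked one
int_mask = np.tril(np.ones(shape, dtype='int64'))

# 'float64': additive mask, 0.0 where attendable, -1e9 where masked
float_mask = (np.tril(np.ones(shape, dtype='float64')) - 1.0) * 1e9

print(int_mask[0, 0])    # lower-triangular 0/1 matrix
print(float_mask[0, 0])  # 0.0 on/below the diagonal, -1e9 above it
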
@@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn):
     k = k.transpose([0, 1, 3, 2])
     qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64))
     if attn_mask is not None:
+        if attn_mask.dtype.name == 'int64':
+            attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
+        else:
+            attn_mask = attn_mask.astype(qkt.dtype)
         qkt += attn_mask
     weight = softmax(qkt)
     attn_heads = batch_matmul(weight, v)
@@ -219,53 +230,57 @@ class TestTransformer(unittest.TestCase):
                 # generate params for multi_head_attention
                 batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params(
                     "attn", self_attention)
-                query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
-                    self_attention, batch_size, num_heads, query_length,
-                    embed_dim, key_length, value_length, kdim, vdim, cache)
-                if cache and self_attention:
-                    attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
-                need_weight, param_attr, bias_attr = False, None, None
-                # call paddle's function
-                multi_head_attn = MultiHeadAttention(
-                    embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight,
-                    param_attr, bias_attr)
-                # construct cache object
-                cache_obj = None
-                if cache_dict:
-                    if 'k' and 'v' in cache_dict:
-                        cache_obj = multi_head_attn.Cache(
-                            paddle.to_tensor(cache_dict['k']),
-                            paddle.to_tensor(cache_dict['v']))
-                    elif 'static_k' and 'static_v' in cache_dict:
-                        cache_obj = multi_head_attn.StaticCache(
-                            paddle.to_tensor(cache_dict['static_k']),
-                            paddle.to_tensor(cache_dict['static_v']))
-                if attn_mask is not None:
-                    attn_output = multi_head_attn(
-                        paddle.to_tensor(query),
-                        paddle.to_tensor(key),
-                        paddle.to_tensor(value),
-                        paddle.to_tensor(attn_mask), cache_obj)
-                else:
-                    attn_output = multi_head_attn(
-                        paddle.to_tensor(query),
-                        paddle.to_tensor(key),
-                        paddle.to_tensor(value), attn_mask, cache_obj)
-                attn_output = attn_output[0] if cache_dict else attn_output
-
-                # implementation by numpy
-                # compute q, k, v
-                q, k, v, _ = prepare_qkv(query, key, value, num_heads,
-                                         embed_dim, self_attention,
-                                         multi_head_attn, cache_dict)
-                # scale dot product attention
-                attn_heads = scaled_dot_product_attention(
-                    q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn)
-                out_proj_weight = multi_head_attn.out_proj.weight.numpy()
-                reference = fc(attn_heads, out_proj_weight)
-
-                np.testing.assert_allclose(
-                    attn_output.numpy(), reference, atol=1e-6)
+                for attn_mask_type in ['int64', 'float64']:
+                    query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
+                        self_attention, batch_size, num_heads, query_length,
+                        embed_dim, attn_mask_type, key_length, value_length,
+                        kdim, vdim, cache)
+                    if cache and self_attention:
+                        attn_mask = np.concatenate(
+                            (attn_mask, attn_mask), axis=3)
+                    need_weight, param_attr, bias_attr = False, None, None
+                    # call paddle's function
+                    multi_head_attn = MultiHeadAttention(
+                        embed_dim, num_heads, attn_dropout, kdim, vdim,
+                        need_weight, param_attr, bias_attr)
+                    # construct cache object
+                    cache_obj = None
+                    if cache_dict:
+                        if 'k' and 'v' in cache_dict:
+                            cache_obj = multi_head_attn.Cache(
+                                paddle.to_tensor(cache_dict['k']),
+                                paddle.to_tensor(cache_dict['v']))
+                        elif 'static_k' and 'static_v' in cache_dict:
+                            cache_obj = multi_head_attn.StaticCache(
+                                paddle.to_tensor(cache_dict['static_k']),
+                                paddle.to_tensor(cache_dict['static_v']))
+                    if attn_mask is not None:
+                        attn_output = multi_head_attn(
+                            paddle.to_tensor(query),
+                            paddle.to_tensor(key),
+                            paddle.to_tensor(value),
+                            paddle.to_tensor(attn_mask), cache_obj)
+                    else:
+                        attn_output = multi_head_attn(
+                            paddle.to_tensor(query),
+                            paddle.to_tensor(key),
+                            paddle.to_tensor(value), attn_mask, cache_obj)
+                    attn_output = attn_output[0] if cache_dict else attn_output
+
+                    # implementation by numpy
+                    # compute q, k, v
+                    q, k, v, _ = prepare_qkv(query, key, value, num_heads,
+                                             embed_dim, self_attention,
+                                             multi_head_attn, cache_dict)
+                    # scale dot product attention
+                    attn_heads = scaled_dot_product_attention(
+                        q, k, v, embed_dim // num_heads, attn_mask,
+                        multi_head_attn)
+                    out_proj_weight = multi_head_attn.out_proj.weight.numpy()
+                    reference = fc(attn_heads, out_proj_weight)
+
+                    np.testing.assert_allclose(
+                        attn_output.numpy(), reference, atol=1e-6)
 
         multihead_attention_test_helper(True, True)
         multihead_attention_test_helper(True, False)
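
And a quick standalone sanity check (again plain numpy, only illustrating the conversion the
patch adds to the scaled_dot_product_attention reference): an 'int64' 0/1 mask, once rewritten
as (mask - 1.0) * 1e9, should match the 'float64' additive mask directly, so both flavors add
the same bias to qkt:

import numpy as np

shape = (1, 1, 3, 3)
int_mask = np.tril(np.ones(shape, dtype='int64'))
float_mask = (np.tril(np.ones(shape, dtype='float64')) - 1.0) * 1e9

# same rule the patch applies after casting to qkt's dtype
converted = (int_mask.astype('float64') - 1.0) * 1e9
np.testing.assert_allclose(converted, float_mask)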