Optimization of Transformer API (#30957)

* Support 'bool' and 'int' for attention mask.

* Update docs.

* Add unittest for Transformer.

* Fix bugs.
xiemoyuan committed via GitHub 4 years ago · commit edacb6293c · parent ee1801c1ad

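The first bullet above is the behavioral change exercised by the test diff below: an attention mask may now be boolean- or integer-valued, with True/1 marking positions that are allowed to attend, instead of having to be a pre-built additive float mask. The test covers the int64 and float64 encodings; the sketch below (plain NumPy, no Paddle required) shows the conversion rule it relies on, extended to bool under the assumption that bool masks follow the same keep/block convention. The helper name is mine, not part of the API.

```python
import numpy as np

def to_additive_mask(mask, dtype="float32"):
    """Hypothetical helper mirroring the test's rule:
    keep-mask (True/1 = may attend) -> additive mask (0 / -1e9)."""
    return (mask.astype(dtype) - 1.0) * 1e9

bool_mask = np.array([[True, True, False]])   # last key position blocked
int_mask = bool_mask.astype("int64")

# Both encodings map to the same additive mask: kept -> 0.0, blocked -> -1e9.
assert np.array_equal(to_additive_mask(bool_mask), to_additive_mask(int_mask))
```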
@@ -51,6 +51,7 @@ def generate_query_key_value_cache(self_attention,
num_heads,
query_length,
embed_dim,
attn_mask_type,
key_length=None,
value_length=None,
kdim=None,
@@ -58,8 +59,14 @@ def generate_query_key_value_cache(self_attention,
cache=None):
query = np.random.rand(batch_size, query_length,
embed_dim).astype("float32")
attn_mask = np.zeros((batch_size, num_heads, query_length, key_length))
attn_mask[0][0][0][0] = -1e9
attn_mask = np.ones(
(batch_size, num_heads, query_length, key_length), dtype=attn_mask_type)
if attn_mask_type == 'int64':
attn_mask = np.tril(attn_mask)
elif attn_mask_type == 'float64':
attn_mask = (np.tril(attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
head_dim = embed_dim // num_heads
if self_attention:
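The hunk above is where the test now builds its mask in one of two equivalent encodings: a lower-triangular int64 keep-mask, or the corresponding float64 additive mask. A small self-check of that equivalence, using the same expressions as the helper (shapes are illustrative):

```python
import numpy as np

batch_size, num_heads, query_length, key_length = 2, 4, 5, 5

ones = np.ones((batch_size, num_heads, query_length, key_length))
int_mask = np.tril(ones).astype("int64")         # 1 = may attend (causal)
float_mask = (np.tril(ones) - 1.0) * 1e9          # additive: 0 or -1e9

# Applying the test's conversion rule to the int mask reproduces the float mask.
converted = (int_mask.astype("float64") - 1.0) * 1e9
assert np.array_equal(converted, float_mask)
```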
@@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn):
k = k.transpose([0, 1, 3, 2])
qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64))
if attn_mask is not None:
if attn_mask.dtype.name == 'int64':
attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
else:
attn_mask = attn_mask.astype(qkt.dtype)
qkt += attn_mask
weight = softmax(qkt)
attn_heads = batch_matmul(weight, v)
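With that change, the NumPy reference accepts either encoding: an int64 mask is converted to additive form before being added to the scaled logits, while a float mask is added as-is. A self-contained sketch of the reference computation, with plain `np.matmul` and an inline softmax standing in for the test's `batch_matmul` and `softmax` helpers:

```python
import numpy as np

def ref_scaled_dot_product_attention(q, k, v, d_key, attn_mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    logits = np.matmul(q, k.transpose(0, 1, 3, 2)) / np.sqrt(d_key)
    if attn_mask is not None:
        if attn_mask.dtype.name == 'int64':
            # keep-mask (1 = attend) -> additive mask (0 / -1e9)
            attn_mask = (attn_mask.astype(logits.dtype) - 1.0) * 1e9
        else:
            attn_mask = attn_mask.astype(logits.dtype)
        logits = logits + attn_mask
    # numerically stable softmax over the key axis
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v)

q = k = v = np.random.rand(2, 4, 5, 8).astype("float32")
causal = np.tril(np.ones((2, 4, 5, 5))).astype("int64")
out = ref_scaled_dot_product_attention(q, k, v, d_key=8, attn_mask=causal)
print(out.shape)  # (2, 4, 5, 8)
```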
@@ -219,53 +230,57 @@ class TestTransformer(unittest.TestCase):
# generate params for multi_head_attention
batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params(
"attn", self_attention)
query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
self_attention, batch_size, num_heads, query_length,
embed_dim, key_length, value_length, kdim, vdim, cache)
if cache and self_attention:
attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
need_weight, param_attr, bias_attr = False, None, None
# call paddle's function
multi_head_attn = MultiHeadAttention(
embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight,
param_attr, bias_attr)
# construct cache object
cache_obj = None
if cache_dict:
if 'k' and 'v' in cache_dict:
cache_obj = multi_head_attn.Cache(
paddle.to_tensor(cache_dict['k']),
paddle.to_tensor(cache_dict['v']))
elif 'static_k' and 'static_v' in cache_dict:
cache_obj = multi_head_attn.StaticCache(
paddle.to_tensor(cache_dict['static_k']),
paddle.to_tensor(cache_dict['static_v']))
if attn_mask is not None:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value),
paddle.to_tensor(attn_mask), cache_obj)
else:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value), attn_mask, cache_obj)
attn_output = attn_output[0] if cache_dict else attn_output
# implementation by numpy
# compute q, k, v
q, k, v, _ = prepare_qkv(query, key, value, num_heads,
embed_dim, self_attention,
multi_head_attn, cache_dict)
# scale dot product attention
attn_heads = scaled_dot_product_attention(
q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn)
out_proj_weight = multi_head_attn.out_proj.weight.numpy()
reference = fc(attn_heads, out_proj_weight)
np.testing.assert_allclose(
attn_output.numpy(), reference, atol=1e-6)
for attn_mask_type in ['int64', 'float64']:
query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
self_attention, batch_size, num_heads, query_length,
embed_dim, attn_mask_type, key_length, value_length,
kdim, vdim, cache)
if cache and self_attention:
attn_mask = np.concatenate(
(attn_mask, attn_mask), axis=3)
need_weight, param_attr, bias_attr = False, None, None
# call paddle's function
multi_head_attn = MultiHeadAttention(
embed_dim, num_heads, attn_dropout, kdim, vdim,
need_weight, param_attr, bias_attr)
# construct cache object
cache_obj = None
if cache_dict:
if 'k' and 'v' in cache_dict:
cache_obj = multi_head_attn.Cache(
paddle.to_tensor(cache_dict['k']),
paddle.to_tensor(cache_dict['v']))
elif 'static_k' and 'static_v' in cache_dict:
cache_obj = multi_head_attn.StaticCache(
paddle.to_tensor(cache_dict['static_k']),
paddle.to_tensor(cache_dict['static_v']))
if attn_mask is not None:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value),
paddle.to_tensor(attn_mask), cache_obj)
else:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value), attn_mask, cache_obj)
attn_output = attn_output[0] if cache_dict else attn_output
# implementation by numpy
# compute q, k, v
q, k, v, _ = prepare_qkv(query, key, value, num_heads,
embed_dim, self_attention,
multi_head_attn, cache_dict)
# scale dot product attention
attn_heads = scaled_dot_product_attention(
q, k, v, embed_dim // num_heads, attn_mask,
multi_head_attn)
out_proj_weight = multi_head_attn.out_proj.weight.numpy()
reference = fc(attn_heads, out_proj_weight)
np.testing.assert_allclose(
attn_output.numpy(), reference, atol=1e-6)
multihead_attention_test_helper(True, True)
multihead_attention_test_helper(True, False)
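Putting it together, the loop above runs the same comparison once per mask encoding, and both runs must match the NumPy reference within 1e-6. A condensed, hedged sketch of the user-facing effect this verifies, assuming a Paddle release that includes this change (constructor arguments are left at their defaults rather than the randomized ones the test draws; dropout defaults to 0.0, so the two calls are deterministic):

```python
import numpy as np
import paddle
from paddle.nn import MultiHeadAttention

batch_size, seq_len, embed_dim, num_heads = 2, 5, 16, 4

query = paddle.to_tensor(
    np.random.rand(batch_size, seq_len, embed_dim).astype("float32"))

# The same causal pattern in both encodings.
keep = np.tril(np.ones((batch_size, num_heads, seq_len, seq_len)))
int_mask = paddle.to_tensor(keep.astype("int64"))            # 1 = attend
float_mask = paddle.to_tensor(((keep - 1.0) * 1e9).astype("float32"))

attn = MultiHeadAttention(embed_dim, num_heads)

out_int = attn(query, query, query, attn_mask=int_mask)
out_float = attn(query, query, query, attn_mask=float_mask)

# If the layer converts int masks with the same (mask - 1) * 1e9 rule the
# test assumes, both encodings should give the same attention output.
np.testing.assert_allclose(out_int.numpy(), out_float.numpy(), atol=1e-6)
```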

File diff suppressed because it is too large.