Optimization of Transformer API (#30957)

* Support 'bool' and 'int' for attention mask.

* Update docs.

* Add unittest for Transformer.

* Fix bugs.
Authored by xiemoyuan 4 years ago, committed via GitHub.
parent ee1801c1ad
commit edacb6293c

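The diff below only covers the updated unit test; the API change itself appears to live in the file whose diff is suppressed at the end of this page. As a rough sketch of the intended usage, assuming the mask semantics exercised by the test (nonzero/True means "may attend", zero/False means "masked", with non-float masks converted internally to an additive 0 / -1e9 mask), a causal boolean mask could be passed directly instead of a hand-built float mask. Shapes and values here are illustrative, not taken from the PR:

    # Illustrative sketch only: shapes and values are made up, not from this PR's test data.
    import numpy as np
    import paddle

    batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2
    query = paddle.to_tensor(
        np.random.rand(batch_size, seq_len, embed_dim).astype("float32"))

    # Causal mask: True (or 1) marks positions that may attend, False (or 0) masked ones.
    # Before this change, a float mask with 0 / -1e9 entries had to be built by hand.
    bool_mask = paddle.to_tensor(
        np.tril(np.ones((batch_size, num_heads, seq_len, seq_len))).astype("bool"))

    attn = paddle.nn.MultiHeadAttention(embed_dim, num_heads)
    out = attn(query, query, query, attn_mask=bool_mask)  # self-attention
    print(out.shape)  # [2, 4, 8]
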
@@ -51,6 +51,7 @@ def generate_query_key_value_cache(self_attention,
                                    num_heads,
                                    query_length,
                                    embed_dim,
+                                   attn_mask_type,
                                    key_length=None,
                                    value_length=None,
                                    kdim=None,
@@ -58,8 +59,14 @@
                                    cache=None):
     query = np.random.rand(batch_size, query_length,
                            embed_dim).astype("float32")
-    attn_mask = np.zeros((batch_size, num_heads, query_length, key_length))
-    attn_mask[0][0][0][0] = -1e9
+    attn_mask = np.ones(
+        (batch_size, num_heads, query_length, key_length), dtype=attn_mask_type)
+    if attn_mask_type == 'int64':
+        attn_mask = np.tril(attn_mask)
+    elif attn_mask_type == 'float64':
+        attn_mask = (np.tril(attn_mask) - 1.0) * 1e9
+    else:
+        raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
 
     head_dim = embed_dim // num_heads
     if self_attention:
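For reference, the two mask layouts that the updated generate_query_key_value_cache produces differ only in encoding: the int64 mask stores keep/drop flags, while the float64 mask is already additive. A tiny numpy-only illustration (a 1x1x3x3 shape instead of the test's real shapes):

    # Numpy-only illustration of the two mask encodings built above (tiny shape).
    import numpy as np

    shape = (1, 1, 3, 3)
    int_mask = np.tril(np.ones(shape, dtype='int64'))                    # 1 = attend, 0 = masked
    float_mask = (np.tril(np.ones(shape, dtype='float64')) - 1.0) * 1e9  # 0 = attend, -1e9 = masked

    print(int_mask[0, 0])    # [[1 0 0] [1 1 0] [1 1 1]]
    print(float_mask[0, 0])  # [[ 0.e+00 -1.e+09 -1.e+09] [ 0.e+00  0.e+00 -1.e+09] ...]
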
@@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn):
     k = k.transpose([0, 1, 3, 2])
     qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64))
     if attn_mask is not None:
+        if attn_mask.dtype.name == 'int64':
+            attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
+        else:
+            attn_mask = attn_mask.astype(qkt.dtype)
         qkt += attn_mask
     weight = softmax(qkt)
     attn_heads = batch_matmul(weight, v)
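The numpy reference now mirrors the framework's behaviour: an integer mask is converted to the additive form (mask - 1) * 1e9 before being added to the scaled logits. A standalone check (not part of the test file; it defines its own softmax helper) that this conversion is equivalent to building the float mask directly:

    # Standalone check that the int64 branch matches a float additive mask built directly.
    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    logits = np.random.rand(3, 3)
    int_mask = np.tril(np.ones((3, 3), dtype='int64'))

    converted = (int_mask.astype(logits.dtype) - 1.0) * 1e9   # same conversion as above
    float_mask = (np.tril(np.ones((3, 3))) - 1.0) * 1e9       # float64 mask built directly

    np.testing.assert_allclose(softmax(logits + converted),
                               softmax(logits + float_mask))
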
@@ -219,53 +230,57 @@ class TestTransformer(unittest.TestCase):
                 # generate params for multi_head_attention
                 batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params(
                     "attn", self_attention)
-                query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
-                    self_attention, batch_size, num_heads, query_length,
-                    embed_dim, key_length, value_length, kdim, vdim, cache)
-                if cache and self_attention:
-                    attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
-                need_weight, param_attr, bias_attr = False, None, None
-                # call paddle's function
-                multi_head_attn = MultiHeadAttention(
-                    embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight,
-                    param_attr, bias_attr)
-                # construct cache object
-                cache_obj = None
-                if cache_dict:
-                    if 'k' and 'v' in cache_dict:
-                        cache_obj = multi_head_attn.Cache(
-                            paddle.to_tensor(cache_dict['k']),
-                            paddle.to_tensor(cache_dict['v']))
-                    elif 'static_k' and 'static_v' in cache_dict:
-                        cache_obj = multi_head_attn.StaticCache(
-                            paddle.to_tensor(cache_dict['static_k']),
-                            paddle.to_tensor(cache_dict['static_v']))
-                if attn_mask is not None:
-                    attn_output = multi_head_attn(
-                        paddle.to_tensor(query),
-                        paddle.to_tensor(key),
-                        paddle.to_tensor(value),
-                        paddle.to_tensor(attn_mask), cache_obj)
-                else:
-                    attn_output = multi_head_attn(
-                        paddle.to_tensor(query),
-                        paddle.to_tensor(key),
-                        paddle.to_tensor(value), attn_mask, cache_obj)
-                attn_output = attn_output[0] if cache_dict else attn_output
-
-                # implementation by numpy
-                # compute q, k, v
-                q, k, v, _ = prepare_qkv(query, key, value, num_heads,
-                                         embed_dim, self_attention,
-                                         multi_head_attn, cache_dict)
-                # scale dot product attention
-                attn_heads = scaled_dot_product_attention(
-                    q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn)
-                out_proj_weight = multi_head_attn.out_proj.weight.numpy()
-                reference = fc(attn_heads, out_proj_weight)
-
-                np.testing.assert_allclose(
-                    attn_output.numpy(), reference, atol=1e-6)
+                for attn_mask_type in ['int64', 'float64']:
+                    query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
+                        self_attention, batch_size, num_heads, query_length,
+                        embed_dim, attn_mask_type, key_length, value_length,
+                        kdim, vdim, cache)
+                    if cache and self_attention:
+                        attn_mask = np.concatenate(
+                            (attn_mask, attn_mask), axis=3)
+                    need_weight, param_attr, bias_attr = False, None, None
+                    # call paddle's function
+                    multi_head_attn = MultiHeadAttention(
+                        embed_dim, num_heads, attn_dropout, kdim, vdim,
+                        need_weight, param_attr, bias_attr)
+                    # construct cache object
+                    cache_obj = None
+                    if cache_dict:
+                        if 'k' and 'v' in cache_dict:
+                            cache_obj = multi_head_attn.Cache(
+                                paddle.to_tensor(cache_dict['k']),
+                                paddle.to_tensor(cache_dict['v']))
+                        elif 'static_k' and 'static_v' in cache_dict:
+                            cache_obj = multi_head_attn.StaticCache(
+                                paddle.to_tensor(cache_dict['static_k']),
+                                paddle.to_tensor(cache_dict['static_v']))
+                    if attn_mask is not None:
+                        attn_output = multi_head_attn(
+                            paddle.to_tensor(query),
+                            paddle.to_tensor(key),
+                            paddle.to_tensor(value),
+                            paddle.to_tensor(attn_mask), cache_obj)
+                    else:
+                        attn_output = multi_head_attn(
+                            paddle.to_tensor(query),
+                            paddle.to_tensor(key),
+                            paddle.to_tensor(value), attn_mask, cache_obj)
+                    attn_output = attn_output[0] if cache_dict else attn_output
+
+                    # implementation by numpy
+                    # compute q, k, v
+                    q, k, v, _ = prepare_qkv(query, key, value, num_heads,
+                                             embed_dim, self_attention,
+                                             multi_head_attn, cache_dict)
+                    # scale dot product attention
+                    attn_heads = scaled_dot_product_attention(
+                        q, k, v, embed_dim // num_heads, attn_mask,
+                        multi_head_attn)
+                    out_proj_weight = multi_head_attn.out_proj.weight.numpy()
+                    reference = fc(attn_heads, out_proj_weight)
+
+                    np.testing.assert_allclose(
+                        attn_output.numpy(), reference, atol=1e-6)
         multihead_attention_test_helper(True, True)
         multihead_attention_test_helper(True, False)
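The loop above runs the same comparison against the numpy reference once per mask dtype. An end-to-end way to see the same property directly in Paddle (illustrative only, and assuming the int mask is converted with the same (mask - 1) * 1e9 rule shown in the reference code) is to feed both encodings to one MultiHeadAttention layer and compare the outputs:

    # Illustrative only: compares the two mask encodings end to end in Paddle.
    import numpy as np
    import paddle

    paddle.seed(2020)
    batch, heads, seq, dim = 2, 2, 4, 8
    q = paddle.to_tensor(np.random.rand(batch, seq, dim).astype('float32'))

    int_mask = np.tril(np.ones((batch, heads, seq, seq), dtype='int64'))
    float_mask = (int_mask.astype('float32') - 1.0) * 1e9

    attn = paddle.nn.MultiHeadAttention(dim, heads, dropout=0.0)
    out_int = attn(q, q, q, attn_mask=paddle.to_tensor(int_mask))
    out_float = attn(q, q, q, attn_mask=paddle.to_tensor(float_mask))

    np.testing.assert_allclose(out_int.numpy(), out_float.numpy(), atol=1e-6)
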

(The remaining file diff was suppressed because it is too large.)