Optimization of Transformer API (#30957)

* Support 'bool' and 'int' for attention mask.

* Update docs.

* Add unittest for Transformer.

* Fix bugs.
xiemoyuan, 4 years ago, committed by GitHub · commit edacb6293c (parent ee1801c1ad)

@@ -51,6 +51,7 @@ def generate_query_key_value_cache(self_attention,
                                    num_heads,
                                    query_length,
                                    embed_dim,
+                                   attn_mask_type,
                                    key_length=None,
                                    value_length=None,
                                    kdim=None,
@@ -58,8 +59,14 @@ def generate_query_key_value_cache(self_attention,
                                    cache=None):
     query = np.random.rand(batch_size, query_length,
                            embed_dim).astype("float32")
-    attn_mask = np.zeros((batch_size, num_heads, query_length, key_length))
-    attn_mask[0][0][0][0] = -1e9
+    attn_mask = np.ones(
+        (batch_size, num_heads, query_length, key_length), dtype=attn_mask_type)
+    if attn_mask_type == 'int64':
+        attn_mask = np.tril(attn_mask)
+    elif attn_mask_type == 'float64':
+        attn_mask = (np.tril(attn_mask) - 1.0) * 1e9
+    else:
+        raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
 
     head_dim = embed_dim // num_heads
     if self_attention:
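The updated helper builds the same causal pattern in two dtypes: an integer mask in which 1 marks positions that may be attended to and 0 marks masked positions, and an additive float mask in which allowed positions hold 0.0 and masked positions hold -1e9. A minimal NumPy sketch of the relationship between the two forms (shapes and values are illustrative only, not the ones generate_basic_params() would draw):

import numpy as np

# Illustrative dimensions only.
batch_size, num_heads, query_length, key_length = 1, 2, 4, 4

# Integer convention: 1 = may attend, 0 = masked. np.tril yields a causal mask.
int_mask = np.tril(
    np.ones((batch_size, num_heads, query_length, key_length), dtype='int64'))

# Additive float convention: 0.0 where attention is allowed, -1e9 where it is
# masked, so softmax pushes masked positions to (almost) zero weight.
float_mask = (int_mask.astype('float64') - 1.0) * 1e9

print(int_mask[0, 0])    # lower-triangular 0/1 pattern
print(float_mask[0, 0])  # 0.0 on and below the diagonal, -1e9 above it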
@@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn):
     k = k.transpose([0, 1, 3, 2])
     qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64))
     if attn_mask is not None:
+        if attn_mask.dtype.name == 'int64':
+            attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
+        else:
+            attn_mask = attn_mask.astype(qkt.dtype)
         qkt += attn_mask
     weight = softmax(qkt)
     attn_heads = batch_matmul(weight, v)
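The reference path now normalizes whichever mask convention it receives into the additive form before adding it to the scaled QK^T scores. A self-contained sketch of that normalization step, assuming NumPy inputs shaped (batch, heads, seq_len, head_dim); the function names are illustrative, not the test's own helpers:

import numpy as np

def _softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def reference_attention(q, k, v, attn_mask=None):
    # attn_mask is broadcastable to (batch, heads, q_len, k_len); an int64
    # 0/1 mask is converted to an additive 0/-1e9 mask before being added.
    d_key = q.shape[-1]
    scores = np.matmul(q, k.transpose(0, 1, 3, 2)) / np.sqrt(d_key)
    if attn_mask is not None:
        if attn_mask.dtype.name == 'int64':
            attn_mask = (attn_mask.astype(scores.dtype) - 1.0) * 1e9
        else:
            attn_mask = attn_mask.astype(scores.dtype)
        scores = scores + attn_mask
    return np.matmul(_softmax(scores), v)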
@@ -219,16 +230,19 @@ class TestTransformer(unittest.TestCase):
                 # generate params for multi_head_attention
                 batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params(
                     "attn", self_attention)
-                query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
-                    self_attention, batch_size, num_heads, query_length,
-                    embed_dim, key_length, value_length, kdim, vdim, cache)
-                if cache and self_attention:
-                    attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
-                need_weight, param_attr, bias_attr = False, None, None
-                # call paddle's function
-                multi_head_attn = MultiHeadAttention(
-                    embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight,
-                    param_attr, bias_attr)
-                # construct cache object
-                cache_obj = None
-                if cache_dict:
+                for attn_mask_type in ['int64', 'float64']:
+                    query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
+                        self_attention, batch_size, num_heads, query_length,
+                        embed_dim, attn_mask_type, key_length, value_length,
+                        kdim, vdim, cache)
+                    if cache and self_attention:
+                        attn_mask = np.concatenate(
+                            (attn_mask, attn_mask), axis=3)
+                    need_weight, param_attr, bias_attr = False, None, None
+                    # call paddle's function
+                    multi_head_attn = MultiHeadAttention(
+                        embed_dim, num_heads, attn_dropout, kdim, vdim,
+                        need_weight, param_attr, bias_attr)
+                    # construct cache object
+                    cache_obj = None
+                    if cache_dict:
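When a cache is used with self-attention, the test doubles the mask along its last axis because the cached keys and values are concatenated with the new ones, so the effective key length doubles. A quick shape check under assumed dimensions:

import numpy as np

# Assumed shape (batch, heads, q_len, k_len), with a cache the same length as
# the new keys, so the mask must cover 2 * k_len key positions.
attn_mask = np.tril(np.ones((1, 2, 4, 4), dtype='int64'))
attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
print(attn_mask.shape)  # (1, 2, 4, 8)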
@@ -260,7 +274,8 @@ class TestTransformer(unittest.TestCase):
                         multi_head_attn, cache_dict)
                     # scale dot product attention
                     attn_heads = scaled_dot_product_attention(
-                        q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn)
+                        q, k, v, embed_dim // num_heads, attn_mask,
+                        multi_head_attn)
                     out_proj_weight = multi_head_attn.out_proj.weight.numpy()
                     reference = fc(attn_heads, out_proj_weight)
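Per the PR description ("Support 'bool' and 'int' for attention mask"), the public layer should now accept boolean and integer masks directly, with the conversion to an additive mask handled internally, mirroring the reference code above. A hedged usage sketch with paddle.nn.MultiHeadAttention; the shapes and the bool-mask semantics (True = may attend) are assumptions taken from the test helper, not a statement of the official API:

import numpy as np
import paddle

batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2

query = paddle.rand([batch_size, seq_len, embed_dim])
# Boolean causal mask: True marks positions that may be attended to.
bool_mask = paddle.to_tensor(
    np.tril(np.ones((batch_size, num_heads, seq_len, seq_len))).astype('bool'))

attn = paddle.nn.MultiHeadAttention(embed_dim, num_heads)
out = attn(query, query, query, attn_mask=bool_mask)
print(out.shape)  # [2, 4, 8]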

(Remaining file diff suppressed because it is too large.)