|
|
@ -25,12 +25,13 @@ __all__ = [
|
|
|
|
import copy
|
|
|
|
import copy
|
|
|
|
import collections
|
|
|
|
import collections
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .common import Linear, Dropout
|
|
|
|
|
|
|
|
from .norm import LayerNorm
|
|
|
|
|
|
|
|
from .. import functional as F
|
|
|
|
|
|
|
|
from ... import tensor
|
|
|
|
from ...fluid import layers
|
|
|
|
from ...fluid import layers
|
|
|
|
|
|
|
|
from ...fluid.dygraph import Layer, LayerList
|
|
|
|
from ...fluid.param_attr import ParamAttr
|
|
|
|
from ...fluid.param_attr import ParamAttr
|
|
|
|
from ...fluid.dygraph import Layer, Linear, Dropout, LayerNorm, LayerList
|
|
|
|
|
|
|
|
from .. import functional as F
|
|
|
|
|
|
|
|
from ...fluid.layers import utils
|
|
|
|
|
|
|
|
from ...fluid.layers.utils import map_structure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_param_attr_to_list(param_attr, n):
|
|
|
|
def _convert_param_attr_to_list(param_attr, n):
|
|
|
@ -103,7 +104,7 @@ class MultiHeadAttention(Layer):
|
|
|
|
# self attention mask: [batch_size, num_heads, query_len, query_len]
|
|
|
|
# self attention mask: [batch_size, num_heads, query_len, query_len]
|
|
|
|
attn_mask = paddle.rand((2, 2, 4, 4))
|
|
|
|
attn_mask = paddle.rand((2, 2, 4, 4))
|
|
|
|
multi_head_attn = paddle.MultiHeadAttention(128, 2)
|
|
|
|
multi_head_attn = paddle.MultiHeadAttention(128, 2)
|
|
|
|
output = multi_head_attn(query, attn_mask=attn_mask) # [2, 4, 128]
|
|
|
|
output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128]
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
Cache = collections.namedtuple("Cache", ["k", "v"])
|
|
|
|
Cache = collections.namedtuple("Cache", ["k", "v"])
|
|
|
@ -176,8 +177,8 @@ class MultiHeadAttention(Layer):
|
|
|
|
and their data types are same as inputs.
|
|
|
|
and their data types are same as inputs.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
q = self.q_proj(query)
|
|
|
|
q = self.q_proj(query)
|
|
|
|
q = layers.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
|
|
|
|
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
|
|
|
|
q = layers.transpose(x=q, perm=[0, 2, 1, 3])
|
|
|
|
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(cache, self.StaticCache):
|
|
|
|
if isinstance(cache, self.StaticCache):
|
|
|
|
# for encoder-decoder attention in inference and has cached
|
|
|
|
# for encoder-decoder attention in inference and has cached
|
|
|
@ -187,8 +188,8 @@ class MultiHeadAttention(Layer):
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(cache, self.Cache):
|
|
|
|
if isinstance(cache, self.Cache):
|
|
|
|
# for decoder self-attention in inference
|
|
|
|
# for decoder self-attention in inference
|
|
|
|
k = layers.concat([cache.k, k], axis=2)
|
|
|
|
k = tensor.concat([cache.k, k], axis=2)
|
|
|
|
v = layers.concat([cache.v, v], axis=2)
|
|
|
|
v = tensor.concat([cache.v, v], axis=2)
|
|
|
|
cache = self.Cache(k, v)
|
|
|
|
cache = self.Cache(k, v)
|
|
|
|
|
|
|
|
|
|
|
|
return (q, k, v) if cache is None else (q, k, v, cache)
|
|
|
|
return (q, k, v) if cache is None else (q, k, v, cache)
|
|
|
@ -219,10 +220,10 @@ class MultiHeadAttention(Layer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
k = self.k_proj(key)
|
|
|
|
k = self.k_proj(key)
|
|
|
|
v = self.v_proj(value)
|
|
|
|
v = self.v_proj(value)
|
|
|
|
k = layers.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
|
|
|
|
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
|
|
|
|
k = layers.transpose(x=k, perm=[0, 2, 1, 3])
|
|
|
|
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
|
|
|
|
v = layers.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
|
|
|
|
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
|
|
|
|
v = layers.transpose(x=v, perm=[0, 2, 1, 3])
|
|
|
|
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
|
|
|
|
return k, v
|
|
|
|
return k, v
|
|
|
|
|
|
|
|
|
|
|
|
def gen_cache(self, key, value=None, type=Cache):
|
|
|
|
def gen_cache(self, key, value=None, type=Cache):
|
|
|
@ -352,24 +353,25 @@ class MultiHeadAttention(Layer):
|
|
|
|
q, k, v, cache = self._prepare_qkv(query, key, value, cache)
|
|
|
|
q, k, v, cache = self._prepare_qkv(query, key, value, cache)
|
|
|
|
|
|
|
|
|
|
|
|
# scale dot product attention
|
|
|
|
# scale dot product attention
|
|
|
|
|
|
|
|
# TODO(guosheng): use tensor.matmul, however it doesn't support `alpha`
|
|
|
|
product = layers.matmul(
|
|
|
|
product = layers.matmul(
|
|
|
|
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
|
|
|
|
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
|
|
|
|
if attn_mask is not None:
|
|
|
|
if attn_mask is not None:
|
|
|
|
# TODO(guosheng): support bool mask
|
|
|
|
# TODO(guosheng): support bool mask
|
|
|
|
product = product + attn_mask
|
|
|
|
product = product + attn_mask
|
|
|
|
weights = layers.softmax(product)
|
|
|
|
weights = F.softmax(product)
|
|
|
|
if self.dropout:
|
|
|
|
if self.dropout:
|
|
|
|
weights = layers.dropout(
|
|
|
|
weights = F.dropout(
|
|
|
|
weights,
|
|
|
|
weights,
|
|
|
|
dropout_prob=self.dropout,
|
|
|
|
self.dropout,
|
|
|
|
dropout_implementation="upscale_in_train",
|
|
|
|
training=self.training,
|
|
|
|
is_test=False)
|
|
|
|
mode="upscale_in_train")
|
|
|
|
|
|
|
|
|
|
|
|
out = layers.matmul(weights, v)
|
|
|
|
out = tensor.matmul(weights, v)
|
|
|
|
|
|
|
|
|
|
|
|
# combine heads
|
|
|
|
# combine heads
|
|
|
|
out = layers.transpose(out, perm=[0, 2, 1, 3])
|
|
|
|
out = tensor.transpose(out, perm=[0, 2, 1, 3])
|
|
|
|
out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
|
|
|
|
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
|
|
|
|
|
|
|
|
|
|
|
|
# project to output
|
|
|
|
# project to output
|
|
|
|
out = self.out_proj(out)
|
|
|
|
out = self.out_proj(out)
|
|
|
@ -429,7 +431,7 @@ class TransformerEncoderLayer(Layer):
|
|
|
|
.. code-block:: python
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import TransformerEncoderLayer
|
|
|
|
from paddle.nn import TransformerEncoderLayer
|
|
|
|
|
|
|
|
|
|
|
|
# encoder input: [batch_size, src_len, d_model]
|
|
|
|
# encoder input: [batch_size, src_len, d_model]
|
|
|
|
enc_input = paddle.rand((2, 4, 128))
|
|
|
|
enc_input = paddle.rand((2, 4, 128))
|
|
|
@ -470,17 +472,14 @@ class TransformerEncoderLayer(Layer):
|
|
|
|
bias_attr=bias_attrs[0])
|
|
|
|
bias_attr=bias_attrs[0])
|
|
|
|
self.linear1 = Linear(
|
|
|
|
self.linear1 = Linear(
|
|
|
|
d_model, dim_feedforward, weight_attrs[1], bias_attr=bias_attrs[1])
|
|
|
|
d_model, dim_feedforward, weight_attrs[1], bias_attr=bias_attrs[1])
|
|
|
|
self.dropout = Dropout(
|
|
|
|
self.dropout = Dropout(act_dropout, mode="upscale_in_train")
|
|
|
|
act_dropout, dropout_implementation="upscale_in_train")
|
|
|
|
|
|
|
|
self.linear2 = Linear(
|
|
|
|
self.linear2 = Linear(
|
|
|
|
dim_feedforward, d_model, weight_attrs[1], bias_attr=bias_attrs[1])
|
|
|
|
dim_feedforward, d_model, weight_attrs[1], bias_attr=bias_attrs[1])
|
|
|
|
self.norm1 = LayerNorm(d_model)
|
|
|
|
self.norm1 = LayerNorm(d_model)
|
|
|
|
self.norm2 = LayerNorm(d_model)
|
|
|
|
self.norm2 = LayerNorm(d_model)
|
|
|
|
self.dropout1 = Dropout(
|
|
|
|
self.dropout1 = Dropout(dropout, mode="upscale_in_train")
|
|
|
|
dropout, dropout_implementation="upscale_in_train")
|
|
|
|
self.dropout2 = Dropout(dropout, mode="upscale_in_train")
|
|
|
|
self.dropout2 = Dropout(
|
|
|
|
self.activation = getattr(F, activation)
|
|
|
|
dropout, dropout_implementation="upscale_in_train")
|
|
|
|
|
|
|
|
self.activation = getattr(layers, activation)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, src, src_mask=None):
|
|
|
|
def forward(self, src, src_mask=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -539,7 +538,7 @@ class TransformerEncoder(Layer):
|
|
|
|
.. code-block:: python
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import TransformerEncoderLayer, TransformerEncoder
|
|
|
|
from paddle.nn import TransformerEncoderLayer, TransformerEncoder
|
|
|
|
|
|
|
|
|
|
|
|
# encoder input: [batch_size, src_len, d_model]
|
|
|
|
# encoder input: [batch_size, src_len, d_model]
|
|
|
|
enc_input = paddle.rand((2, 4, 128))
|
|
|
|
enc_input = paddle.rand((2, 4, 128))
|
|
|
@ -643,7 +642,7 @@ class TransformerDecoderLayer(Layer):
|
|
|
|
.. code-block:: python
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import TransformerDecoderLayer
|
|
|
|
from paddle.nn import TransformerDecoderLayer
|
|
|
|
|
|
|
|
|
|
|
|
# decoder input: [batch_size, tgt_len, d_model]
|
|
|
|
# decoder input: [batch_size, tgt_len, d_model]
|
|
|
|
dec_input = paddle.rand((2, 4, 128))
|
|
|
|
dec_input = paddle.rand((2, 4, 128))
|
|
|
@ -697,20 +696,16 @@ class TransformerDecoderLayer(Layer):
|
|
|
|
bias_attr=bias_attrs[1])
|
|
|
|
bias_attr=bias_attrs[1])
|
|
|
|
self.linear1 = Linear(
|
|
|
|
self.linear1 = Linear(
|
|
|
|
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
|
|
|
|
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
|
|
|
|
self.dropout = Dropout(
|
|
|
|
self.dropout = Dropout(act_dropout, mode="upscale_in_train")
|
|
|
|
act_dropout, dropout_implementation="upscale_in_train")
|
|
|
|
|
|
|
|
self.linear2 = Linear(
|
|
|
|
self.linear2 = Linear(
|
|
|
|
dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2])
|
|
|
|
dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2])
|
|
|
|
self.norm1 = LayerNorm(d_model)
|
|
|
|
self.norm1 = LayerNorm(d_model)
|
|
|
|
self.norm2 = LayerNorm(d_model)
|
|
|
|
self.norm2 = LayerNorm(d_model)
|
|
|
|
self.norm3 = LayerNorm(d_model)
|
|
|
|
self.norm3 = LayerNorm(d_model)
|
|
|
|
self.dropout1 = Dropout(
|
|
|
|
self.dropout1 = Dropout(dropout, mode="upscale_in_train")
|
|
|
|
dropout, dropout_implementation="upscale_in_train")
|
|
|
|
self.dropout2 = Dropout(dropout, mode="upscale_in_train")
|
|
|
|
self.dropout2 = Dropout(
|
|
|
|
self.dropout3 = Dropout(dropout, mode="upscale_in_train")
|
|
|
|
dropout, dropout_implementation="upscale_in_train")
|
|
|
|
self.activation = getattr(F, activation)
|
|
|
|
self.dropout3 = Dropout(
|
|
|
|
|
|
|
|
dropout, dropout_implementation="upscale_in_train")
|
|
|
|
|
|
|
|
self.activation = getattr(layers, activation)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
|
|
|
|
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -834,7 +829,7 @@ class TransformerDecoder(Layer):
|
|
|
|
.. code-block:: python
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import TransformerDecoderLayer, TransformerDecoder
|
|
|
|
from paddle.nn import TransformerDecoderLayer, TransformerDecoder
|
|
|
|
|
|
|
|
|
|
|
|
# decoder input: [batch_size, tgt_len, d_model]
|
|
|
|
# decoder input: [batch_size, tgt_len, d_model]
|
|
|
|
dec_input = paddle.rand((2, 4, 128))
|
|
|
|
dec_input = paddle.rand((2, 4, 128))
|
|
|
@ -1017,7 +1012,7 @@ class Transformer(Layer):
|
|
|
|
.. code-block:: python
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import Transformer
|
|
|
|
from paddle.nn import Transformer
|
|
|
|
|
|
|
|
|
|
|
|
# src: [batch_size, tgt_len, d_model]
|
|
|
|
# src: [batch_size, tgt_len, d_model]
|
|
|
|
enc_input = paddle.rand((2, 4, 128))
|
|
|
|
enc_input = paddle.rand((2, 4, 128))
|
|
|
|