@@ -311,7 +311,7 @@ class MultiHeadAttention(Layer):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)
 
-    def forward(self, query, key, value, attn_mask=None, cache=None):
+    def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
         r"""
         Applies multi-head attention to map queries and a set of key-value pairs
         to outputs.
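Note on the hunk above: `key` and `value` now default to `None`, so a plain self-attention call can pass only `query`. A minimal sketch of the intended call pattern, assuming omitted `key`/`value` fall back to `query` as the new defaults suggest (sizes are illustrative, not taken from this patch):

```python
import paddle
import paddle.nn as nn

attn = nn.MultiHeadAttention(embed_dim=128, num_heads=2)
query = paddle.rand([2, 4, 128])  # [batch_size, sequence_length, embed_dim]

out_explicit = attn(query, query, query)  # explicit self-attention call
out_short = attn(query)                   # key/value omitted; presumably equivalent
```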
@@ -498,7 +498,7 @@ class TransformerEncoderLayer(Layer):
         self.dropout2 = Dropout(dropout, mode="upscale_in_train")
         self.activation = getattr(F, activation)
 
-    def forward(self, src, src_mask=None):
+    def forward(self, src, src_mask=None, cache=None):
         r"""
         Applies a Transformer encoder layer on the input.
 
@@ -514,16 +514,30 @@ class TransformerEncoderLayer(Layer):
                 have 0 values. The data type should be float32 or float64. It can
                 be None when nothing wanted or needed to be prevented attention to.
                 Default None
+            cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`.
+                See `TransformerEncoderLayer.gen_cache` for more details. It is
+                only used for inference and should be None for training. Default
+                None.
 
         Returns:
-            Tensor: The output of Transformer encoder layer. It is a tensor that \
-                has the same shape and data type as `enc_input`.
+            Tensor|tuple: It is a tensor that has the same shape and data type \
+                as `enc_input`, representing the output of Transformer encoder \
+                layer. Or a tuple if `cache` is not None, except for encoder \
+                layer output, the tuple includes the new cache which is same \
+                as input `cache` argument but `incremental_cache` has an \
+                incremental length. See `MultiHeadAttention.gen_cache` and \
+                `MultiHeadAttention.forward` for more details.
         """
         residual = src
         if self.normalize_before:
             src = self.norm1(src)
-        # TODO(guosheng): Add cache for encoder for the usage like UniLM
-        src = self.self_attn(src, src, src, src_mask)
+        if cache is None:
+            src = self.self_attn(src, src, src, src_mask)
+        else:
+            src, incremental_cache = self.self_attn(src, src, src, src_mask,
+                                                    cache)
         src = residual + self.dropout1(src)
         if not self.normalize_before:
             src = self.norm1(src)
@@ -535,7 +549,28 @@ class TransformerEncoderLayer(Layer):
         src = residual + self.dropout2(src)
         if not self.normalize_before:
             src = self.norm2(src)
-        return src
+        return src if cache is None else (src, incremental_cache)
+
+    def gen_cache(self, src):
+        r"""
+        Generates cache for `forward` usage. The generated cache is an
+        instance of `MultiHeadAttention.Cache`.
+
+        Parameters:
+            src (Tensor): The input of Transformer encoder. It is a tensor
+                with shape `[batch_size, source_length, d_model]`. The data
+                type should be float32 or float64.
+
+        Returns:
+            incremental_cache: It is an instance of `MultiHeadAttention.Cache` \
+                produced by `self_attn.gen_cache`, it reserves two tensors
+                shaped `[batch_size, nhead, 0, d_model // nhead]`. See \
+                `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
+                for more details.
+        """
+        incremental_cache = self.self_attn.gen_cache(
+            src, type=self.self_attn.Cache)
+        return incremental_cache
 
 
 class TransformerEncoder(Layer):
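For context on the new `gen_cache`/`cache` pair added to `TransformerEncoderLayer` above, a minimal usage sketch for UniLM-style incremental encoding at inference time; the step loop and sizes are illustrative assumptions, not part of this patch:

```python
import paddle
import paddle.nn as nn

batch_size, d_model, nhead = 2, 128, 2  # illustrative sizes
layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=512)
layer.eval()  # the cache path is documented as inference-only

# gen_cache wraps self_attn.gen_cache and returns an empty (length-0) Cache.
cache = layer.gen_cache(paddle.rand([batch_size, 1, d_model]))

for _ in range(4):
    step_input = paddle.rand([batch_size, 1, d_model])
    # With `cache` given, forward returns (output, incremental_cache); the
    # returned cache should have grown by the new step's length.
    step_output, cache = layer(step_input, src_mask=None, cache=cache)
```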
@@ -574,7 +609,7 @@ class TransformerEncoder(Layer):
         self.num_layers = num_layers
         self.norm = norm
 
-    def forward(self, src, src_mask=None):
+    def forward(self, src, src_mask=None, cache=None):
         r"""
         Applies a stack of N Transformer encoder layers on inputs. If `norm` is
         provided, also applies layer normalization on the output of last encoder
@@ -592,20 +627,55 @@ class TransformerEncoder(Layer):
                 have 0 values. The data type should be float32 or float64. It can
                 be None when nothing wanted or needed to be prevented attention to.
                 Default None
+            cache (list, optional): It is a list, and each element in the list
+                is `incremental_cache` produced by `TransformerEncoderLayer.gen_cache`.
+                See `TransformerEncoder.gen_cache` for more details. It is only
+                used for inference and should be None for training. Default None.
 
         Returns:
-            Tensor: The output of Transformer encoder. It is a tensor that \
-                has the same shape and data type as `src`.
+            Tensor|tuple: It is a tensor that has the same shape and data type \
+                as `src`, representing the output of Transformer encoder. \
+                Or a tuple if `cache` is not None, except for encoder output, \
+                the tuple includes the new cache which is same as input `cache` \
+                argument but `incremental_cache` in it has an incremental length. \
+                See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
+                for more details.
         """
         output = src
 
-        for mod in self.layers:
-            output = mod(output, src_mask=src_mask)
+        new_caches = []
+        for i, mod in enumerate(self.layers):
+            if cache is None:
+                output = mod(output, src_mask=src_mask)
+            else:
+                output, new_cache = mod(output,
+                                        src_mask=src_mask,
+                                        cache=cache[i])
+                new_caches.append(new_cache)
 
         if self.norm is not None:
             output = self.norm(output)
 
-        return output
+        return output if cache is None else (output, new_caches)
+
+    def gen_cache(self, src):
+        r"""
+        Generates cache for `forward` usage. The generated cache is a list, and
+        each element in it is `incremental_cache` produced by
+        `TransformerEncoderLayer.gen_cache`. See `TransformerEncoderLayer.gen_cache`
+        for more details.
+
+        Parameters:
+            src (Tensor): The input of Transformer encoder. It is a tensor
+                with shape `[batch_size, source_length, d_model]`. The data type
+                should be float32 or float64.
+
+        Returns:
+            list: It is a list, and each element in the list is `incremental_cache`
+                produced by `TransformerEncoderLayer.gen_cache`. See
+                `TransformerEncoderLayer.gen_cache` for more details.
+        """
+        cache = [layer.gen_cache(src) for layer in self.layers]
+        return cache
 
 
 class TransformerDecoderLayer(Layer):
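Analogously, a sketch of the stacked `TransformerEncoder` driving the per-layer cache list introduced above; again the loop and sizes are illustrative assumptions rather than part of the patch:

```python
import paddle
import paddle.nn as nn

batch_size, d_model, nhead, num_layers = 2, 128, 2, 3  # illustrative sizes
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=512), num_layers)
encoder.eval()

# One incremental_cache per layer, produced by each layer's gen_cache.
caches = encoder.gen_cache(paddle.rand([batch_size, 1, d_model]))

for _ in range(4):
    step_input = paddle.rand([batch_size, 1, d_model])
    # With `cache` given, forward returns (output, new_caches); re-feed the
    # returned list on the next step.
    step_output, caches = encoder(step_input, src_mask=None, cache=caches)
```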