@@ -131,6 +131,7 @@ class EmbeddingLookup(nn.Cell):
         self.shape = P.Shape()
 
     def construct(self, input_ids):
+        """Get an embeddings lookup table with a fixed dictionary and size."""
         input_shape = self.shape(input_ids)
 
         flat_ids = self.reshape(input_ids, self.shape_flat)
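
For readers skimming the hunk above: the new docstring summarizes a flatten-gather-reshape pattern. A minimal NumPy sketch of that pattern (the sizes and the `embedding_table` name are illustrative, not taken from the patched file):

    import numpy as np

    vocab_size, embedding_size = 32, 8                  # hypothetical sizes
    embedding_table = np.random.randn(vocab_size, embedding_size)

    input_ids = np.array([[1, 5, 2], [7, 0, 3]])        # [batch, seq_len]
    input_shape = input_ids.shape
    flat_ids = input_ids.reshape(-1)                    # mirrors reshape(input_ids, shape_flat)
    flat_embeddings = embedding_table[flat_ids]         # row gather from the table
    output = flat_embeddings.reshape(input_shape + (embedding_size,))
    print(output.shape)                                 # (2, 3, 8)
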
@@ -200,6 +201,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.shape = P.Shape()
 
     def construct(self, word_embeddings):
+        """Postprocessors apply positional embeddings to word embeddings."""
         input_shape = self.shape(word_embeddings)
         input_len = input_shape[1]
 
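
The docstring added above describes the positional-embedding step. A rough NumPy sketch of that step, with made-up shapes and a hypothetical `full_position_table`:

    import numpy as np

    max_len, hidden = 16, 8
    full_position_table = np.random.randn(max_len, hidden)     # hypothetical position table

    word_embeddings = np.random.randn(2, 5, hidden)            # [batch, seq_len, hidden]
    input_len = word_embeddings.shape[1]
    position_embeddings = full_position_table[:input_len]      # slice to the current length
    output = word_embeddings + position_embeddings             # broadcast over the batch axis
    print(output.shape)                                        # (2, 5, 8)
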
@@ -377,7 +379,7 @@ class MultiheadAttention(nn.Cell):
         self.softmax_cast = P.Cast()
 
     def construct(self, from_tensor, to_tensor, attention_mask=None):
-        # reshape 2d/3d input tensors to 2d
+        """reshape 2d/3d input tensors to 2d"""
         from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
         to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
         query_out = self.query_layer(from_tensor_2d)
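
The comment-turned-docstring above refers to flattening the leading axes before the dense projections. A small NumPy sketch of why that works (the weight matrix and shapes are invented for the example): a dense layer is a matmul over the last axis, so a [batch*seq, hidden] view covers both 2d and 3d inputs.

    import numpy as np

    batch, seq_len, hidden = 2, 4, 8
    from_tensor = np.random.randn(batch, seq_len, hidden)
    query_weight = np.random.randn(hidden, hidden)       # stand-in for query_layer's weight

    from_tensor_2d = from_tensor.reshape(-1, hidden)     # [batch*seq, hidden]
    query_out = from_tensor_2d @ query_weight            # same math for 2d and 3d inputs
    print(query_out.shape)                               # (8, 8)
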
@@ -476,6 +478,7 @@ class SelfAttention(nn.Cell):
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
     def construct(self, input_tensor, memory_tensor, attention_mask):
+        """Apply self-attention."""
         input_tensor = self.reshape(input_tensor, self.shape)
         memory_tensor = self.reshape(memory_tensor, self.shape)
 
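
As context for the one-line docstring above: the same static `(-1, hidden_size)` shape flattens both the input tensor and the memory tensor, so the block can serve self-attention (memory equals input) as well as encoder-decoder attention. A NumPy illustration with invented shapes:

    import numpy as np

    hidden_size = 8
    shape = (-1, hidden_size)

    input_tensor = np.random.randn(2, 5, hidden_size)    # e.g. decoder activations
    memory_tensor = np.random.randn(2, 7, hidden_size)   # e.g. encoder states

    print(input_tensor.reshape(shape).shape)             # (10, 8)
    print(memory_tensor.reshape(shape).shape)            # (14, 8)
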
@@ -831,6 +834,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
         self.batch_matmul = P.BatchMatMul()
 
     def construct(self, input_mask):
+        """Create attention mask according to input mask."""
         input_shape = self.shape(input_mask)
         shape_right = (input_shape[0], 1, input_shape[1])
         shape_left = input_shape + (1,)
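
The reshapes above set up a per-example outer product: a [batch, seq] 0/1 input mask becomes a [batch, seq, seq] attention mask via BatchMatMul. A NumPy sketch with toy values:

    import numpy as np

    input_mask = np.array([[1, 1, 1, 0],
                           [1, 1, 0, 0]], dtype=np.float32)             # [batch, seq]
    input_shape = input_mask.shape

    mask_left = input_mask.reshape(input_shape + (1,))                  # [batch, seq, 1]
    mask_right = input_mask.reshape(input_shape[0], 1, input_shape[1])  # [batch, 1, seq]
    attention_mask = mask_left @ mask_right                             # batched matmul
    print(attention_mask.shape)                                         # (2, 4, 4)
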
@@ -876,6 +880,7 @@ class PredLogProbs(nn.Cell):
     def construct(self,
                   input_tensor,
                   output_weights):
+        """Get log probs."""
         input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
         input_tensor = self.cast(input_tensor, self.compute_type)
         output_weights = self.cast(output_weights, self.compute_type)
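
The hunk only shows the flatten-and-cast prologue; the usual continuation (projecting onto the transposed embedding table and taking a log-softmax) is sketched below as an assumption about the tied-weight head, in NumPy with toy sizes:

    import numpy as np

    batch, seq_len, hidden, vocab = 2, 3, 4, 10
    input_tensor = np.random.randn(batch, seq_len, hidden).astype(np.float32)
    output_weights = np.random.randn(vocab, hidden).astype(np.float32)   # embedding table

    flat = input_tensor.reshape(-1, hidden)              # flatten the sequence tensor
    logits = flat @ output_weights.T                     # [batch*seq, vocab]  (assumed step)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    print(log_probs.shape)                               # (6, 10)
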
@@ -962,7 +967,10 @@ class TransformerDecoderStep(nn.Cell):
         self.cast_compute_type = CastWrapper(dst_type=compute_type)
 
     def construct(self, input_ids, enc_states, enc_attention_mask):
-        # input_ids: [batch_size * beam_width]
+        """
+        Multi-layer transformer decoder step.
+        input_ids: [batch_size * beam_width]
+        """
         # process embedding
         input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids)
         input_embedding = self.tfm_embedding_processor(input_embedding)
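
The new docstring notes the `[batch_size * beam_width]` convention: beam search folds the beam axis into the batch axis so the decoder step sees an ordinary batch. A NumPy illustration with made-up sizes:

    import numpy as np

    batch_size, beam_width, current_len, vocab = 2, 3, 4, 100
    input_ids = np.random.randint(0, vocab, size=(batch_size, beam_width, current_len))

    flat_ids = input_ids.reshape(batch_size * beam_width, current_len)
    print(flat_ids.shape)                                        # (6, 4): what the step consumes

    scores = np.random.randn(batch_size * beam_width, vocab)
    print(scores.reshape(batch_size, beam_width, vocab).shape)   # (2, 3, 100): unfolded per beam
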
@@ -1122,6 +1130,7 @@ class TransformerModel(nn.Cell):
         self.encdec_mask = Tensor(ones, dtype=mstype.float32)
 
     def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
+        """Transformer with encoder and decoder."""
         # process source sentence
         src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
         src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
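
To round off the encoder-side flow shown in this last hunk (lookup, then positional postprocessing, then the encoder), here is a hedged end-to-end NumPy sketch; the helper names and the identity "encoder" are stand-ins, not attributes of TransformerModel:

    import numpy as np

    def encode_source(source_ids, source_mask, embedding_table, position_table, encoder_fn):
        src_word_embeddings = embedding_table[source_ids]                     # lookup
        src_embedding_output = src_word_embeddings + position_table[:source_ids.shape[1]]
        enc_attention_mask = source_mask[:, None, :] * source_mask[:, :, None]
        return encoder_fn(src_embedding_output, enc_attention_mask)

    vocab, hidden, max_len = 20, 8, 16
    table = np.random.randn(vocab, hidden)
    positions = np.random.randn(max_len, hidden)
    ids = np.array([[3, 5, 7, 0]])
    mask = np.array([[1, 1, 1, 0]], dtype=np.float32)
    print(encode_source(ids, mask, table, positions, lambda x, m: x).shape)   # (1, 4, 8)
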