|
|
|
|
@ -20,11 +20,11 @@ import numpy as np
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
import mindspore.ops.functional as F
|
|
|
|
|
from mindspore.common.initializer import TruncatedNormal, initializer
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from .beam_search import BeamSearchDecoder, TileBeam
|
|
|
|
|
from .weight_init import normal_weight, weight_variable
|
|
|
|
|
|
|
|
|
|
class TransformerConfig:
|
|
|
|
|
"""
|
|
|
|
|
@ -118,9 +118,7 @@ class EmbeddingLookup(nn.Cell):
|
|
|
|
|
self.vocab_size = vocab_size
|
|
|
|
|
self.embedding_size = embedding_size
|
|
|
|
|
self.use_one_hot_embeddings = use_one_hot_embeddings
|
|
|
|
|
self.embedding_table = Parameter(initializer
|
|
|
|
|
(TruncatedNormal(initializer_range),
|
|
|
|
|
[vocab_size, embedding_size]),
|
|
|
|
|
self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size),
|
|
|
|
|
name='embedding_table')
|
|
|
|
|
self.expand = P.ExpandDims()
|
|
|
|
|
self.shape_flat = (-1,)
|
|
|
|
|
@ -138,8 +136,7 @@ class EmbeddingLookup(nn.Cell):
|
|
|
|
|
flat_ids = self.reshape(input_ids, self.shape_flat)
|
|
|
|
|
if self.use_one_hot_embeddings:
|
|
|
|
|
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
|
|
|
|
output_for_reshape = self.array_mul(
|
|
|
|
|
one_hot_ids, self.embedding_table)
|
|
|
|
|
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
|
|
|
|
|
else:
|
|
|
|
|
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
|
|
|
|
|
|
|
|
|
@ -329,22 +326,22 @@ class MultiheadAttention(nn.Cell):
|
|
|
|
|
units,
|
|
|
|
|
activation=query_act,
|
|
|
|
|
has_bias=False,
|
|
|
|
|
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
|
|
|
|
weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type)
|
|
|
|
|
self.key_layer = nn.Dense(to_tensor_width,
|
|
|
|
|
units,
|
|
|
|
|
activation=key_act,
|
|
|
|
|
has_bias=False,
|
|
|
|
|
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
|
|
|
|
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
|
|
|
|
|
self.value_layer = nn.Dense(to_tensor_width,
|
|
|
|
|
units,
|
|
|
|
|
activation=value_act,
|
|
|
|
|
has_bias=False,
|
|
|
|
|
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
|
|
|
|
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
|
|
|
|
|
self.out_layer = nn.Dense(units,
|
|
|
|
|
out_tensor_width,
|
|
|
|
|
activation=out_act,
|
|
|
|
|
has_bias=False,
|
|
|
|
|
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
|
|
|
|
weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)
|
|
|
|
|
|
|
|
|
|
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
|
|
|
|
|
self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
|
|
|
|
|
@ -518,10 +515,10 @@ class FeedForward(nn.Cell):
|
|
|
|
|
self.conv1 = nn.Dense(in_channels,
|
|
|
|
|
hidden_size,
|
|
|
|
|
activation=hidden_act,
|
|
|
|
|
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
|
|
|
|
weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type)
|
|
|
|
|
self.conv2 = nn.Dense(hidden_size,
|
|
|
|
|
out_channels,
|
|
|
|
|
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
|
|
|
|
weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type)
|
|
|
|
|
|
|
|
|
|
self.preprocess = LayerPreprocess(in_channels=in_channels)
|
|
|
|
|
self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob)
|
|
|
|
|
@ -1108,7 +1105,13 @@ class TransformerModel(nn.Cell):
|
|
|
|
|
embedding_size=self.embedding_size,
|
|
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
|
|
|
initializer_range=config.initializer_range)
|
|
|
|
|
self.tfm_embedding_postprocessor = EmbeddingPostprocessor(
|
|
|
|
|
self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
|
|
|
|
|
embedding_size=self.embedding_size,
|
|
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
|
|
|
initializer_range=0.02,
|
|
|
|
|
max_position_embeddings=config.max_position_embeddings,
|
|
|
|
|
dropout_prob=config.hidden_dropout_prob)
|
|
|
|
|
self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
|
|
|
|
|
embedding_size=self.embedding_size,
|
|
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
|
|
|
initializer_range=0.02,
|
|
|
|
|
@ -1171,7 +1174,7 @@ class TransformerModel(nn.Cell):
|
|
|
|
|
hidden_act=config.hidden_act,
|
|
|
|
|
compute_type=config.compute_type,
|
|
|
|
|
embedding_lookup=self.tfm_embedding_lookup,
|
|
|
|
|
embedding_processor=self.tfm_embedding_postprocessor,
|
|
|
|
|
embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
|
|
|
|
|
projection=self.projection)
|
|
|
|
|
self.tfm_decoder = BeamSearchDecoder(
|
|
|
|
|
batch_size=config.batch_size,
|
|
|
|
|
@ -1195,15 +1198,14 @@ class TransformerModel(nn.Cell):
|
|
|
|
|
ones = np.ones(shape=(self.seq_length, self.seq_length))
|
|
|
|
|
self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
|
|
|
|
|
else:
|
|
|
|
|
self.tile_beam = TileBeam(
|
|
|
|
|
beam_width=config.beam_width)
|
|
|
|
|
self.tile_beam = TileBeam(beam_width=config.beam_width)
|
|
|
|
|
ones = np.ones(shape=(config.batch_size, config.max_decode_length))
|
|
|
|
|
self.encdec_mask = Tensor(ones, dtype=mstype.float32)
|
|
|
|
|
|
|
|
|
|
def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
|
|
|
|
|
# process source sentence
|
|
|
|
|
src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
|
|
|
|
|
src_embedding_output = self.tfm_embedding_postprocessor(src_word_embeddings)
|
|
|
|
|
src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
|
|
|
|
|
# attention mask [batch_size, seq_length, seq_length]
|
|
|
|
|
enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
|
|
|
|
|
# transformer encoder
|
|
|
|
|
@ -1213,7 +1215,7 @@ class TransformerModel(nn.Cell):
|
|
|
|
|
if self.is_training:
|
|
|
|
|
# process target sentence
|
|
|
|
|
tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
|
|
|
|
|
tgt_embedding_output = self.tfm_embedding_postprocessor(tgt_word_embeddings)
|
|
|
|
|
tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
|
|
|
|
|
# attention mask [batch_size, seq_length, seq_length]
|
|
|
|
|
tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
|
|
|
|
|
tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
|
|
|
|
|
@ -1223,15 +1225,14 @@ class TransformerModel(nn.Cell):
|
|
|
|
|
encoder_output, enc_attention_mask)
|
|
|
|
|
# calculate logits and log_probs
|
|
|
|
|
log_probs = self.projection(decoder_output, embedding_tables)
|
|
|
|
|
return log_probs
|
|
|
|
|
|
|
|
|
|
beam_encoder_output = self.tile_beam(encoder_output)
|
|
|
|
|
ret = log_probs
|
|
|
|
|
else:
|
|
|
|
|
beam_encoder_output = self.tile_beam(encoder_output)
|
|
|
|
|
|
|
|
|
|
enc_attention_mask = self.multiply(
|
|
|
|
|
enc_attention_mask[::, 0:1:1, ::],
|
|
|
|
|
self.expand(self.encdec_mask, -1))
|
|
|
|
|
enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1))
|
|
|
|
|
|
|
|
|
|
beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
|
|
|
|
|
beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
|
|
|
|
|
predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
|
|
|
|
|
return predicted_ids
|
|
|
|
|
beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
|
|
|
|
|
beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
|
|
|
|
|
predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
|
|
|
|
|
ret = predicted_ids
|
|
|
|
|
return ret
|
|
|
|
|
|