From a084532a6e4043ceabaece4aefd2c4bfa690d6db Mon Sep 17 00:00:00 2001 From: linqingke Date: Tue, 27 Oct 2020 16:14:40 +0800 Subject: [PATCH] modify mass beam_search's floordiv and mod to adapt to 310. --- .../nlp/mass/src/transformer/beam_search.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/nlp/mass/src/transformer/beam_search.py b/model_zoo/official/nlp/mass/src/transformer/beam_search.py index 4ab25847c6..dfd8b26b50 100644 --- a/model_zoo/official/nlp/mass/src/transformer/beam_search.py +++ b/model_zoo/official/nlp/mass/src/transformer/beam_search.py @@ -224,6 +224,11 @@ class BeamSearchDecoder(nn.Cell): self.one = Tensor(1, mstype.int32) self.prob_concat = P.Concat(axis=1) + self.greater_equal = P.GreaterEqual() + self.sub = P.Sub() + self.cast = P.Cast() + self.zeroslike = P.ZerosLike() + def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished, state_length, entire_log_probs): """ @@ -261,8 +266,19 @@ class BeamSearchDecoder(nn.Cell): topk_scores, topk_indices = self.topk(flat_scores, self.beam_width) # convert to beam and word indices, [batch, beam] - beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor) - word_indices = self.mod(topk_indices, self.vocab_size_tensor) + # beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor) + # word_indices = self.mod(topk_indices, self.vocab_size_tensor) + # ====================================================================== + # replace floor_div and mod op, since these two ops only support fp16 on + # Ascend310, which will cause overflow. 
+ temp = topk_indices + beam_indices = self.zeroslike(topk_indices) + for _ in range(self.beam_width - 1): + temp = self.sub(temp, self.vocab_size_tensor) + res = self.cast(self.greater_equal(temp, 0), mstype.int32) + beam_indices = beam_indices + res + word_indices = topk_indices - beam_indices * self.vocab_size_tensor + # ====================================================================== current_word_pro = self.gather_nd( log_probs,