@@ -168,6 +168,11 @@ class BeamSearchDecoder(nn.Cell):
         self.concat = P.Concat(axis=-1)
         self.gather_nd = P.GatherNd()
 
+        self.greater_equal = P.GreaterEqual()
+        self.sub = P.Sub()
+        self.cast = P.Cast()
+        self.zeroslike = P.ZerosLike()
+
         # init inputs and states
         self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
         self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
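The four operators added in this first hunk (`GreaterEqual`, `Sub`, `Cast`, `ZerosLike`) are the building blocks for the fp16-safe index split introduced in the next hunk. The start tensors shown as context are simply SOS-filled buffers; a minimal NumPy sketch of their shapes, assuming made-up values `batch_size=2`, `beam_width=3`, `sos_id=1`:

```python
import numpy as np

batch_size, beam_width, sos_id = 2, 3, 1  # assumed values, for illustration only

# one <sos> token per (batch, beam) hypothesis, flattened for the decoder input
start_ids = np.full([batch_size * beam_width, 1], sos_id, dtype=np.int32)  # shape (6, 1)

# the same tokens kept per beam as the initial output sequence
init_seq = np.full([batch_size, beam_width, 1], sos_id, dtype=np.int32)   # shape (2, 3, 1)

print(start_ids.shape, init_seq.shape)  # (6, 1) (2, 3, 1)
```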
@@ -199,8 +204,19 @@ class BeamSearchDecoder(nn.Cell):
         topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
 
         # convert to beam and word indices
-        beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
-        word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        #beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
+        #word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        #======================================================================
+        # Replace the floor_div and mod ops: on Ascend310 they only support
+        # fp16, which overflows for large flat indices.
+        temp = topk_indices
+        beam_indices = self.zeroslike(topk_indices)
+        for _ in range(self.beam_width - 1):
+            temp = self.sub(temp, self.vocab_size_tensor)
+            res = self.cast(self.greater_equal(temp, 0), mstype.int32)
+            beam_indices = beam_indices + res
+        word_indices = topk_indices - beam_indices * self.vocab_size_tensor
+        #======================================================================
 
         # mask finished indices
         beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
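For reference, the replacement loop computes exactly the beam/word split that `floor_div`/`mod` produced, using only subtraction, comparison, and cast, which stay exact for int32. The original ops run in fp16 on Ascend310, and fp16 cannot represent every integer above 2048 (its largest finite value is 65504), so flat indices up to `beam_width * vocab_size` silently lose precision or overflow. A standalone NumPy sketch of the equivalence, with made-up sizes:

```python
import numpy as np

beam_width, vocab_size = 4, 30000                     # assumed sizes, for illustration
topk_indices = np.array([123, 35007, 60001, 119999])  # flat indices in [0, beam_width * vocab_size)

# original formulation: one floor division and one modulo
beam_ref = topk_indices // vocab_size
word_ref = topk_indices % vocab_size

# patched formulation: count how many whole vocab_size blocks fit, by repeated subtraction
temp = topk_indices.copy()
beam_indices = np.zeros_like(topk_indices)
for _ in range(beam_width - 1):
    temp = temp - vocab_size
    beam_indices = beam_indices + (temp >= 0).astype(beam_indices.dtype)
word_indices = topk_indices - beam_indices * vocab_size

assert (beam_indices == beam_ref).all()
assert (word_indices == word_ref).all()
```

The loop body executes `beam_width - 1` times regardless of the index values, so it maps cleanly onto static-graph execution, which is presumably why the patch prefers it over a data-dependent formulation.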