@@ -219,7 +219,7 @@ class BeamSearchDecoder(nn.Cell):
         self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
         if self.is_using_while:
             self.start = Tensor(0, dtype=mstype.int32)
-            self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id),
+            self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length + 1], sos_id),
                                    mstype.int32)
         else:
             self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
@@ -402,7 +402,7 @@ class BeamSearchDecoder(nn.Cell):
         accu_attn_scores = self.accu_attn_scores

         if not self.is_using_while:
-            for _ in range(self.max_decode_length + 1):
+            for _ in range(self.max_decode_length):
                 cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
                 state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
                                                state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,
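
Both hunks address the same off-by-one: with is_using_while, init_seq has to reserve one slot for the SOS token plus one slot per decode step, while the unrolled path starts from a single SOS column and appends exactly max_decode_length tokens. A minimal NumPy sketch of that sizing argument, under the assumption that the shapes mirror the diff (the per-step token below is a placeholder, not the real one_step() output):

    import numpy as np

    batch_size, beam_width, max_decode_length, sos_id = 2, 4, 8, 1

    # While-loop path: one slot for the SOS token plus one per decode step.
    init_seq = np.full([batch_size, beam_width, max_decode_length + 1], sos_id, dtype=np.int32)

    # Unrolled path: start from the SOS column and append max_decode_length tokens.
    seq = np.full([batch_size, beam_width, 1], sos_id, dtype=np.int32)
    for step in range(max_decode_length):
        next_ids = np.full([batch_size, beam_width, 1], step, dtype=np.int32)  # placeholder token ids
        seq = np.concatenate([seq, next_ids], axis=-1)

    # Both paths end up with the same sequence length: max_decode_length + 1.
    assert seq.shape[-1] == init_seq.shape[-1] == max_decode_length + 1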