From 930fdc4e3132988e6b99ec68f33e7db747c500bc Mon Sep 17 00:00:00 2001 From: gaojing Date: Mon, 1 Feb 2021 08:03:35 -0500 Subject: [PATCH] modify uncorrect name --- model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py | 4 ++-- model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py index 68fa456dab..401efc9ac7 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py @@ -115,7 +115,7 @@ class BahdanauAttention(nn.Cell): if self.is_training: query_trans = self.cast(query_trans, mstype.float16) processed_query = self.linear_q(query_trans) - if self.is_trining: + if self.is_training: processed_query = self.cast(processed_query, mstype.float32) processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units)) # (N, t_k_length, D) @@ -123,7 +123,7 @@ class BahdanauAttention(nn.Cell): if self.is_training: keys = self.cast(keys, mstype.float16) processed_key = self.linear_k(keys) - if self.is_trining: + if self.is_training: processed_key = self.cast(processed_key, mstype.float32) processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units)) diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py index 99a0ced031..dde14062c3 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py @@ -219,7 +219,7 @@ class BeamSearchDecoder(nn.Cell): self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) if self.is_using_while: self.start = Tensor(0, dtype=mstype.int32) - self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), + self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length + 1], sos_id), mstype.int32) else: self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) @@ -402,7 +402,7 @@ class BeamSearchDecoder(nn.Cell): accu_attn_scores = self.accu_attn_scores if not self.is_using_while: - for _ in range(self.max_decode_length + 1): + for _ in range(self.max_decode_length): cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,