diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py index 5e8dc8436b..09504abcd8 100644 --- a/model_zoo/official/nlp/tinybert/src/tinybert_model.py +++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py @@ -844,9 +844,10 @@ class BertModel(nn.Cell): attention_mask) sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) # pooler + batch_size = P.Shape()(input_ids)[0] sequence_slice = self.slice(sequence_output, (0, 0, 0), - (-1, 1, self.hidden_size), + (batch_size, 1, self.hidden_size), (1, 1, 1)) first_token = self.squeeze_1(sequence_slice) pooled_output = self.dense(first_token) @@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell): attention_mask) sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) # pooler + batch_size = P.Shape()(input_ids)[0] sequence_slice = self.slice(sequence_output, (0, 0, 0), - (-1, 1, self.hidden_size), + (batch_size, 1, self.hidden_size), (1, 1, 1)) first_token = self.squeeze_1(sequence_slice) pooled_output = self.dense(first_token)