|
|
|
@ -844,9 +844,10 @@ class BertModel(nn.Cell):
|
|
|
|
|
attention_mask)
|
|
|
|
|
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
|
|
|
|
# pooler
|
|
|
|
|
batch_size = P.Shape()(input_ids)[0]
|
|
|
|
|
sequence_slice = self.slice(sequence_output,
|
|
|
|
|
(0, 0, 0),
|
|
|
|
|
(-1, 1, self.hidden_size),
|
|
|
|
|
(batch_size, 1, self.hidden_size),
|
|
|
|
|
(1, 1, 1))
|
|
|
|
|
first_token = self.squeeze_1(sequence_slice)
|
|
|
|
|
pooled_output = self.dense(first_token)
|
|
|
|
@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell):
|
|
|
|
|
attention_mask)
|
|
|
|
|
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
|
|
|
|
# pooler
|
|
|
|
|
batch_size = P.Shape()(input_ids)[0]
|
|
|
|
|
sequence_slice = self.slice(sequence_output,
|
|
|
|
|
(0, 0, 0),
|
|
|
|
|
(-1, 1, self.hidden_size),
|
|
|
|
|
(batch_size, 1, self.hidden_size),
|
|
|
|
|
(1, 1, 1))
|
|
|
|
|
first_token = self.squeeze_1(sequence_slice)
|
|
|
|
|
pooled_output = self.dense(first_token)
|
|
|
|
|