|
|
|
@ -964,7 +964,7 @@ class BertModelCLS(nn.Cell):
|
|
|
|
|
The returned output represents the final logits as the result of log_softmax is proportional to that of softmax.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0,
|
|
|
|
|
use_one_hot_embeddings=False, phase_type="teacher"):
|
|
|
|
|
use_one_hot_embeddings=False, phase_type="student"):
|
|
|
|
|
super(BertModelCLS, self).__init__()
|
|
|
|
|
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
@ -992,4 +992,6 @@ class BertModelCLS(nn.Cell):
|
|
|
|
|
logits = self.dense_1(cls)
|
|
|
|
|
logits = self.cast(logits, self.dtype)
|
|
|
|
|
log_probs = self.log_softmax(logits)
|
|
|
|
|
return seq_output, att_output, logits, log_probs
|
|
|
|
|
if self._phase == 'train' or self.phase_type == "teacher":
|
|
|
|
|
return seq_output, att_output, logits, log_probs
|
|
|
|
|
return log_probs
|
|
|
|
|