diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index 4459d9842c..64f6c96a72 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -258,7 +258,7 @@ class BertNetworkWithLoss(nn.Cell): return self.cast(total_loss, mstype.float32) -class BertTrainOneStepCell(nn.Cell): +class BertTrainOneStepCell(nn.TrainOneStepCell): """ Encapsulation class of bert network training. diff --git a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py index 632f8825bd..303b04937f 100644 --- a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py @@ -275,7 +275,7 @@ class BertNetworkWithLoss(nn.Cell): return self.cast(total_loss, mstype.float32) -class BertTrainOneStepCell(nn.Cell): +class BertTrainOneStepCell(nn.TrainOneStepCell): """ Encapsulation class of bert network training. @@ -287,7 +287,6 @@ class BertTrainOneStepCell(nn.Cell): optimizer (Optimizer): Optimizer for updating the weights. sens (Number): The adjust parameter. Default: 1.0. """ - def __init__(self, network, optimizer, sens=1.0): super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) self.cast = P.Cast()