From dc78af73652c0b5cdf90027ae2ceb7929382d313 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Tue, 15 Sep 2020 17:49:16 +0800 Subject: [PATCH] fix bug in bert --- model_zoo/official/nlp/bert/src/bert_for_pre_training.py | 2 +- model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index 4459d9842c..64f6c96a72 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -258,7 +258,7 @@ class BertNetworkWithLoss(nn.Cell): return self.cast(total_loss, mstype.float32) -class BertTrainOneStepCell(nn.Cell): +class BertTrainOneStepCell(nn.TrainOneStepCell): """ Encapsulation class of bert network training. diff --git a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py index 632f8825bd..303b04937f 100644 --- a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py @@ -275,7 +275,7 @@ class BertNetworkWithLoss(nn.Cell): return self.cast(total_loss, mstype.float32) -class BertTrainOneStepCell(nn.Cell): +class BertTrainOneStepCell(nn.TrainOneStepCell): """ Encapsulation class of bert network training. @@ -287,7 +287,6 @@ class BertTrainOneStepCell(nn.Cell): optimizer (Optimizer): Optimizer for updating the weights. sens (Number): The adjust parameter. Default: 1.0. """ - def __init__(self, network, optimizer, sens=1.0): super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) self.cast = P.Cast()