fix bug in bert

pull/6272/head
wangnan39@huawei.com 4 years ago
parent 4abb33f151
commit dc78af7365

@ -258,7 +258,7 @@ class BertNetworkWithLoss(nn.Cell):
return self.cast(total_loss, mstype.float32)
class BertTrainOneStepCell(nn.Cell):
class BertTrainOneStepCell(nn.TrainOneStepCell):
"""
Encapsulation class of bert network training.

@ -275,7 +275,7 @@ class BertNetworkWithLoss(nn.Cell):
return self.cast(total_loss, mstype.float32)
class BertTrainOneStepCell(nn.Cell):
class BertTrainOneStepCell(nn.TrainOneStepCell):
"""
Encapsulation class of bert network training.
@ -287,7 +287,6 @@ class BertTrainOneStepCell(nn.Cell):
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
self.cast = P.Cast()

Loading…
Cancel
Save