|
|
@ -275,7 +275,7 @@ class BertNetworkWithLoss(nn.Cell):
|
|
|
|
return self.cast(total_loss, mstype.float32)
|
|
|
|
return self.cast(total_loss, mstype.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BertTrainOneStepCell(nn.Cell):
|
|
|
|
class BertTrainOneStepCell(nn.TrainOneStepCell):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Encapsulation class of bert network training.
|
|
|
|
Encapsulation class of bert network training.
|
|
|
|
|
|
|
|
|
|
|
@ -287,7 +287,6 @@ class BertTrainOneStepCell(nn.Cell):
|
|
|
|
optimizer (Optimizer): Optimizer for updating the weights.
|
|
|
|
optimizer (Optimizer): Optimizer for updating the weights.
|
|
|
|
sens (Number): The adjust parameter. Default: 1.0.
|
|
|
|
sens (Number): The adjust parameter. Default: 1.0.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, network, optimizer, sens=1.0):
|
|
|
|
def __init__(self, network, optimizer, sens=1.0):
|
|
|
|
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
|
|
|
|
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
|
|
|
|
self.cast = P.Cast()
|
|
|
|
self.cast = P.Cast()
|
|
|
|