|
|
|
@ -216,6 +216,7 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, scale_update_cell=None):
|
|
|
|
|
super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.set_grad()
|
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.grad = C.GradOperation(get_by_list=True,
|
|
|
|
@ -306,6 +307,7 @@ class BertTrainCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, sens=1.0):
|
|
|
|
|
super(BertTrainCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.set_grad()
|
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.sens = sens
|
|
|
|
@ -470,6 +472,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, scale_update_cell=None):
|
|
|
|
|
super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.set_grad()
|
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.grad = C.GradOperation(get_by_list=True,
|
|
|
|
@ -556,6 +559,7 @@ class BertEvaluationCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, sens=1.0):
|
|
|
|
|
super(BertEvaluationCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.set_grad()
|
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.sens = sens
|
|
|
|
|