|
|
|
@ -403,9 +403,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|
|
|
|
sens=None):
|
|
|
|
|
"""Defines the computation performed."""
|
|
|
|
|
weights = self.weights
|
|
|
|
|
# alloc status
|
|
|
|
|
init = self.alloc_status()
|
|
|
|
|
self.clear_before_grad(init)
|
|
|
|
|
loss = self.network(input_ids,
|
|
|
|
|
input_mask,
|
|
|
|
|
token_type_id,
|
|
|
|
@ -417,6 +414,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|
|
|
|
scaling_sens = self.loss_scale
|
|
|
|
|
else:
|
|
|
|
|
scaling_sens = sens
|
|
|
|
|
# alloc status and clear should be right before gradoperation
|
|
|
|
|
init = self.alloc_status()
|
|
|
|
|
self.clear_before_grad(init)
|
|
|
|
|
grads = self.grad(self.network, weights)(input_ids,
|
|
|
|
|
input_mask,
|
|
|
|
|
token_type_id,
|
|
|
|
|