fix bug of auto control depend for bert pre training

add comment
pull/915/head
huangdongrun 6 years ago committed by chang zherui
parent 04471939f5
commit 10ee3fbca5

@ -403,9 +403,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
sens=None): sens=None):
"""Defines the computation performed.""" """Defines the computation performed."""
weights = self.weights weights = self.weights
# alloc status
init = self.alloc_status()
self.clear_before_grad(init)
loss = self.network(input_ids, loss = self.network(input_ids,
input_mask, input_mask,
token_type_id, token_type_id,
@ -417,6 +414,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
scaling_sens = self.loss_scale scaling_sens = self.loss_scale
else: else:
scaling_sens = sens scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids, grads = self.grad(self.network, weights)(input_ids,
input_mask, input_mask,
token_type_id, token_type_id,

Loading…
Cancel
Save