|
|
|
@ -341,7 +341,6 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
self.get_status = P.NPUGetFloatStatus()
|
|
|
|
|
self.clear_before_grad = P.NPUClearFloatStatus()
|
|
|
|
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
|
|
|
|
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
|
|
|
|
self.base = Tensor(1, mstype.float32)
|
|
|
|
|
self.less_equal = P.LessEqual()
|
|
|
|
|
self.hyper_map = C.HyperMap()
|
|
|
|
@ -381,24 +380,24 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
saved = ()
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
saved = saved + (F.assign(self.saved_params[i], weights[i]),)
|
|
|
|
|
assign_embedding = ()
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_embedding_list_length):
|
|
|
|
|
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
|
|
|
|
|
assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
|
|
|
|
|
F.control_depend(saved, assign_embedding[i])
|
|
|
|
|
assign_weight = ()
|
|
|
|
|
quant_embedding = F.depend(quant_embedding, saved)
|
|
|
|
|
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_embedding)
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_weight_list_length):
|
|
|
|
|
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
|
|
|
|
|
assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
|
|
|
|
|
F.control_depend(saved, assign_weight[i])
|
|
|
|
|
for i in range(self.quant_embedding_list_length):
|
|
|
|
|
F.control_depend(assign_embedding[i], input_ids)
|
|
|
|
|
for i in range(self.quant_weight_list_length):
|
|
|
|
|
F.control_depend(assign_weight[i], input_ids)
|
|
|
|
|
quant_weight = F.depend(quant_weight, saved)
|
|
|
|
|
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_weight)
|
|
|
|
|
|
|
|
|
|
if sens is None:
|
|
|
|
|
scaling_sens = self.loss_scale
|
|
|
|
|
else:
|
|
|
|
|
scaling_sens = sens
|
|
|
|
|
|
|
|
|
|
# alloc status and clear should be right before grad operation
|
|
|
|
|
init = self.alloc_status()
|
|
|
|
|
self.clear_before_grad(init)
|
|
|
|
@ -408,15 +407,15 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
label_ids,
|
|
|
|
|
self.cast(scaling_sens,
|
|
|
|
|
mstype.float32))
|
|
|
|
|
F.control_depend(input_ids, grads)
|
|
|
|
|
# apply grad reducer on grads
|
|
|
|
|
grads = self.grad_reducer(grads)
|
|
|
|
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
|
|
|
|
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
|
|
|
|
|
restore = ()
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
weights[i] = F.depend(weights[i], grads)
|
|
|
|
|
restore = restore + (F.assign(weights[i], self.saved_params[i]),)
|
|
|
|
|
F.control_depend(grads, restore[i])
|
|
|
|
|
|
|
|
|
|
self.get_status(init)
|
|
|
|
|
flag_sum = self.reduce_sum(init, (0,))
|
|
|
|
|
if self.is_distributed:
|
|
|
|
@ -432,8 +431,7 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
|
succ = False
|
|
|
|
|
else:
|
|
|
|
|
succ = self.optimizer(grads)
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
F.control_depend(restore[i], succ)
|
|
|
|
|
succ = F.depend(succ, restore)
|
|
|
|
|
return succ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -495,35 +493,33 @@ class BertTrainCell(nn.Cell):
|
|
|
|
|
saved = ()
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
saved = saved + (F.assign(self.saved_params[i], weights[i]),)
|
|
|
|
|
assign_embedding = ()
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_embedding_list_length):
|
|
|
|
|
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
|
|
|
|
|
assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
|
|
|
|
|
F.control_depend(saved, assign_embedding[i])
|
|
|
|
|
assign_weight = ()
|
|
|
|
|
quant_embedding = F.depend(quant_embedding, saved)
|
|
|
|
|
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_embedding)
|
|
|
|
|
|
|
|
|
|
for i in range(self.quant_weight_list_length):
|
|
|
|
|
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
|
|
|
|
|
assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
|
|
|
|
|
F.control_depend(saved, assign_weight[i])
|
|
|
|
|
for i in range(self.quant_embedding_list_length):
|
|
|
|
|
F.control_depend(assign_embedding[i], input_ids)
|
|
|
|
|
for i in range(self.quant_weight_list_length):
|
|
|
|
|
F.control_depend(assign_weight[i], input_ids)
|
|
|
|
|
quant_weight = F.depend(quant_weight, saved)
|
|
|
|
|
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
|
|
|
|
|
input_ids = F.depend(input_ids, assign_weight)
|
|
|
|
|
|
|
|
|
|
grads = self.grad(self.network, weights)(input_ids,
|
|
|
|
|
input_mask,
|
|
|
|
|
token_type_id,
|
|
|
|
|
label_ids,
|
|
|
|
|
self.cast(F.tuple_to_array((self.sens,)),
|
|
|
|
|
mstype.float32))
|
|
|
|
|
F.control_depend(input_ids, grads)
|
|
|
|
|
# apply grad reducer on grads
|
|
|
|
|
grads = self.grad_reducer(grads)
|
|
|
|
|
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
|
|
|
|
|
restore = ()
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
weights[i] = F.depend(weights[i], grads)
|
|
|
|
|
restore = restore + (F.assign(weights[i], self.saved_params[i]),)
|
|
|
|
|
F.control_depend(grads, restore[i])
|
|
|
|
|
|
|
|
|
|
succ = self.optimizer(grads)
|
|
|
|
|
for i in range(self.length):
|
|
|
|
|
F.control_depend(restore[i], succ)
|
|
|
|
|
succ = F.depend(succ, restore)
|
|
|
|
|
return succ
|
|
|
|
|