!13877 remove ControlDepend from ternarybert

From: @huangbingjian
Reviewed-by: @ginfung,@hwhewei
Signed-off-by: @hwhewei
pull/13877/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1732ed34b6

@ -423,9 +423,9 @@ class Depend(Primitive):
In order to ensure that operator A is executed before operator B, it is recommended to
insert the Depend operator between operators A and B. The usage method is as follows::
out_a = A(in_a)
in_b = Depend(in_b, out_a)
out_b = B(in_b)
a = A(x) ---> a = A(x)
b = B(y) ---> y = Depend(y, a)
---> b = B(y)
Inputs:
- **value** (Tensor) - the real value to return for depend operator.

@ -377,21 +377,16 @@ class BertTrainWithLossScaleCell(nn.Cell):
sens=None):
"""Defines the computation performed."""
weights = self.weights
saved = ()
for i in range(self.length):
saved = saved + (F.assign(self.saved_params[i], weights[i]),)
F.assign(self.saved_params[i], weights[i])
for i in range(self.quant_embedding_list_length):
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
quant_embedding = F.depend(quant_embedding, saved)
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
input_ids = F.depend(input_ids, assign_embedding)
F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
for i in range(self.quant_weight_list_length):
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
quant_weight = F.depend(quant_weight, saved)
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
input_ids = F.depend(input_ids, assign_weight)
F.assign(weights[self.quant_weight_list[i]], quant_weight)
if sens is None:
scaling_sens = self.loss_scale
@ -411,10 +406,10 @@ class BertTrainWithLossScaleCell(nn.Cell):
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
restore = ()
for i in range(self.length):
weights[i] = F.depend(weights[i], grads)
restore = restore + (F.assign(weights[i], self.saved_params[i]),)
param = F.depend(self.saved_params[i], grads)
F.assign(weights[i], param)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
@ -431,7 +426,6 @@ class BertTrainWithLossScaleCell(nn.Cell):
succ = False
else:
succ = self.optimizer(grads)
succ = F.depend(succ, restore)
return succ
@ -490,21 +484,16 @@ class BertTrainCell(nn.Cell):
label_ids):
"""Defines the computation performed."""
weights = self.weights
saved = ()
for i in range(self.length):
saved = saved + (F.assign(self.saved_params[i], weights[i]),)
F.assign(self.saved_params[i], weights[i])
for i in range(self.quant_embedding_list_length):
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
quant_embedding = F.depend(quant_embedding, saved)
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
input_ids = F.depend(input_ids, assign_embedding)
F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
for i in range(self.quant_weight_list_length):
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
quant_weight = F.depend(quant_weight, saved)
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
input_ids = F.depend(input_ids, assign_weight)
F.assign(weights[self.quant_weight_list[i]], quant_weight)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
@ -515,11 +504,10 @@ class BertTrainCell(nn.Cell):
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
restore = ()
for i in range(self.length):
weights[i] = F.depend(weights[i], grads)
restore = restore + (F.assign(weights[i], self.saved_params[i]),)
param = F.depend(self.saved_params[i], grads)
F.assign(weights[i], param)
succ = self.optimizer(grads)
succ = F.depend(succ, restore)
return succ

Loading…
Cancel
Save