diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 9b0216557b..9d9f4467e0 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -423,9 +423,9 @@ class Depend(Primitive): In order to ensure that operator A is executed before operator B, it is recommended to insert the Depend operator between operators A and B. The usage method is as follows:: - out_a = A(in_a) - in_b = Depend(in_b, out_a) - out_b = B(in_b) + a = A(x) ---> a = A(x) + b = B(y) ---> y = Depend(y, a) + ---> b = B(y) Inputs: - **value** (Tensor) - the real value to return for depend operator. diff --git a/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py b/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py index 4b585f1ab3..04906f434c 100644 --- a/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py +++ b/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py @@ -377,21 +377,16 @@ class BertTrainWithLossScaleCell(nn.Cell): sens=None): """Defines the computation performed.""" weights = self.weights - saved = () for i in range(self.length): - saved = saved + (F.assign(self.saved_params[i], weights[i]),) + F.assign(self.saved_params[i], weights[i]) for i in range(self.quant_embedding_list_length): quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]]) - quant_embedding = F.depend(quant_embedding, saved) - assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding) - input_ids = F.depend(input_ids, assign_embedding) + F.assign(weights[self.quant_embedding_list[i]], quant_embedding) for i in range(self.quant_weight_list_length): quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]]) - quant_weight = F.depend(quant_weight, saved) - assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight) - input_ids = F.depend(input_ids, assign_weight) + F.assign(weights[self.quant_weight_list[i]], quant_weight) if sens is None: scaling_sens = self.loss_scale @@ -411,10 +406,10 @@ class BertTrainWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads) - restore = () + for i in range(self.length): - weights[i] = F.depend(weights[i], grads) - restore = restore + (F.assign(weights[i], self.saved_params[i]),) + param = F.depend(self.saved_params[i], grads) + F.assign(weights[i], param) self.get_status(init) flag_sum = self.reduce_sum(init, (0,)) @@ -431,7 +426,6 @@ class BertTrainWithLossScaleCell(nn.Cell): succ = False else: succ = self.optimizer(grads) - succ = F.depend(succ, restore) return succ @@ -490,21 +484,16 @@ class BertTrainCell(nn.Cell): label_ids): """Defines the computation performed.""" weights = self.weights - saved = () for i in range(self.length): - saved = saved + (F.assign(self.saved_params[i], weights[i]),) + F.assign(self.saved_params[i], weights[i]) for i in range(self.quant_embedding_list_length): quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]]) - quant_embedding = F.depend(quant_embedding, saved) - assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding) - input_ids = F.depend(input_ids, assign_embedding) + F.assign(weights[self.quant_embedding_list[i]], quant_embedding) for i in range(self.quant_weight_list_length): quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]]) - quant_weight = F.depend(quant_weight, saved) - assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight) - input_ids = F.depend(input_ids, assign_weight) + F.assign(weights[self.quant_weight_list[i]], quant_weight) grads = self.grad(self.network, weights)(input_ids, input_mask, @@ -515,11 +504,10 @@ class BertTrainCell(nn.Cell): # apply grad reducer on grads grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads) - restore = () + for i in range(self.length): - weights[i] = F.depend(weights[i], grads) - restore = restore + (F.assign(weights[i], self.saved_params[i]),) + param = F.depend(self.saved_params[i], grads) + F.assign(weights[i], param) succ = self.optimizer(grads) - succ = F.depend(succ, restore) return succ