diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index db72b775fa..bbbc9f48c2 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -1090,19 +1090,25 @@ def get_bprop_ce_with_logits_loss(self): add = P.Add() sub = P.Sub() size = P.Size() + neg = P.Neg() + log = P.Log() def bprop(predict, target, weight, pos_weight, out, dout): sigmoid_input = sigmoid(predict) if pos_weight is not None: t = mul(target, pos_weight) dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout) + grad_target = mul(sub(log(sub(1, sigmoid_input)), mul(pos_weight, log(sigmoid_input))), dout) else: dx = mul((sigmoid_input - target), dout) + grad_target = mul(predict, neg(dout)) if weight is not None: dx = mul(dx, weight) + grad_target = mul(grad_target, weight) if reduction == 'mean': dx = dx / size(dx) - return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight) + grad_target = grad_target / size(target) + return dx, grad_target, zeros_like(weight), zeros_like(pos_weight) return bprop