diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py index 2ba3fbd55f..d6e9fe2b84 100644 --- a/mindspore/ops/_grad/grad_implementations.py +++ b/mindspore/ops/_grad/grad_implementations.py @@ -14,29 +14,40 @@ # ============================================================================ """bprop primitives""" +from ..operations import _grad_ops as G from .. import functional as F +from .. import operations as P from ..composite import multitype_ops as C from .grad_base import bprops +get_dtype = P.DType() # Unused parameters are placeholders. @bprops.register("MaximumGrad") def bprop_maximum_grad_grad(x, y, z, out, dout): """Backpropagator for primitive `MaximumGrad`.""" - return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z) + out0 = F.cast(out[0] != 0, get_dtype(dout[0])) + out1 = F.cast(out[1] != 0, get_dtype(dout[1])) + dz = out0 * dout[0] + out1 * dout[1] + return F.zeros_like(x), F.zeros_like(y), dz @bprops.register("MinimumGrad") def bprop_minimum_grad_grad(x, y, z, out, dout): """Backpropagator for primitive `MinimumGrad`.""" - return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z) + out0 = F.cast(out[0] != 0, get_dtype(dout[0])) + out1 = F.cast(out[1] != 0, get_dtype(dout[1])) + dz = out0 * dout[0] + out1 * dout[1] + return F.zeros_like(x), F.zeros_like(y), dz @bprops.register("ReluGrad") def bprop_relu_grad_grad(x, y, out, dout): """Backpropagator for primitive `ReluGrad`.""" - return F.zeros_like(x), F.zeros_like(y) + input_grad = G.ReluGrad() + dy = input_grad(dout, y) + return dy, F.zeros_like(y) @bprops.register("scalar_add")