diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 4934a13381..855a68d51b 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """Define the grad rules of neural network related operations."""
+import math
 import numpy as np
 from mindspore.ops import _selected_grad_ops as SG
 from mindspore.ops.primitive import constexpr
@@ -628,19 +629,62 @@ def get_bprop_onehot(self):
     return bprop
 
+@constexpr
+def _range_op(start, limit, delta, dtype):
+    """Helper for TopK bprop: 1-D tensor [start, start+delta, ...) of the given dtype."""
+    range_op = inner.Range(float(start), float(limit), float(delta))
+    length_input = math.ceil((limit - start) / delta)
+    input_tensor = Tensor(list(range(length_input)), dtype)
+    range_out = range_op(input_tensor)
+    return range_out
+
+@constexpr
+def _get_1d_shape(in_shape):
+    """Helper for TopK bprop: total element count of `in_shape`, as a 1-tuple."""
+    out_shape = 1
+    for i in in_shape:
+        out_shape *= i
+    return (out_shape,)
+
 @bprop_getters.register(P.TopK)
 def get_bprop_top_kv2(self):
     """Grad definition for `TopK` operation."""
     scatter = P.ScatterNd()
     expand_dims = P.ExpandDims()
     shape_op = P.Shape()
+    reshape_op = P.Reshape()
+    dtype = P.DType()
 
     def bprop(input_x, k, out, dout):
+
+        # (n_1, n_2, ..., n_p), in_lastdim = n_p
+        in_shape = shape_op(input_x)
+        in_lastdim = in_shape[-1]
+
+        # (n_1, ..., n_(p-1), k), ind_lastdim = k
         indices = out[1]
-        indices = expand_dims(indices, -1)
-        updates = dout[0]
-        shapes = shape_op(input_x)
-        return scatter(indices, updates, shapes), zeros_like(k)
+        ind_shape = shape_op(indices)
+        ind_lastdim = ind_shape[-1]
+
+        # (n_1*n_2*...*n_(p-1), k), outerdim = n_1*n_2*...*n_(p-1)
+        ind_2d = reshape_op(indices, (-1, ind_lastdim))
+        outerdim = shape_op(ind_2d)[0]
+
+        # row offsets into flattened input: [0, in_lastdim, ..., (outerdim-1)*in_lastdim]
+        indices_dtype = dtype(indices)
+        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
+
+        # expand_dims to (outerdim, 1), then broadcast-add against (outerdim, k)
+        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
+        in_shape_1d = _get_1d_shape(in_shape)
+
+        out_grad = reshape_op(
+            scatter(
+                expand_dims(ind, -1),
+                reshape_op(dout[0], (-1,)),
+                in_shape_1d),
+            in_shape)
+        return out_grad, zeros_like(k)
 
     return bprop