From aacad990c3a23cd9c6631585a79ada41318930b6 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Fri, 8 Jan 2021 10:49:14 +0800 Subject: [PATCH] Add grad of sort. --- mindspore/ops/_grad/grad_array_ops.py | 86 ++++++++++++++++++++++++++- tests/ut/python/ops/test_ops.py | 4 ++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 4314489f42..e1dfeac2cd 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -15,8 +15,10 @@ """array_ops""" +import numpy as np import mindspore as ms from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor from .. import operations as P from ..operations import _grad_ops as G from ..operations import _inner_ops as inner @@ -459,6 +461,87 @@ def get_bprop_sparse_gather_v2(self): return bprop +@constexpr +def _range_op(start, limit, delta, dtype): + """helper function for grad of Sort""" + output_tensor = Tensor(list(range(start, limit, delta)), dtype) + return output_tensor + +@constexpr +def _get_1d_shape(in_shape): + """helper function for grad of Sort""" + out_shape = 1 + for i in in_shape: + out_shape *= i + return (out_shape,) + +@constexpr +def _get_transposition(axis, rank): + """helper function for grad of Sort""" + if axis < 0: + axis += rank + transposition = np.r_[np.arange(axis), [rank - 1], np.arange(axis + 1, rank - 1), [axis]] + trans = tuple(transposition.tolist()) + return trans + +@bprop_getters.register(P.Sort) +def get_bprop_sort(self): + """Grad definition for `Sort` operation.""" + axis = self.axis + descending = self.descending + scatter = P.ScatterNd() + expand_dims = P.ExpandDims() + reshape_op = P.Reshape() + dtype = P.DType() + topk = P.TopK() + neg = P.Neg() + tranpose = P.Transpose() + + def bprop(input_x, out, dout): + x_shape = input_x.shape + k = x_shape[axis] + rank = F.rank(input_x) + dvalue = dout[0] + if not descending: + input_x = neg(input_x) + dvalue = neg(dvalue) + if axis == -1 or (axis + 1) == rank: + transposition = None + top_k_input = input_x + else: + transposition = _get_transposition(axis, rank) + top_k_input = tranpose(input_x, transposition) + + _, indices = topk(top_k_input, k) + ind_shape = indices.shape + top_k_input_shape = top_k_input.shape + in_lastdim = top_k_input_shape[-1] + ind_lastdim = ind_shape[-1] + ind_2d = reshape_op(indices, (-1, ind_lastdim)) + outer_dim = ind_2d.shape[0] + + # [0, outterdim, 2*outerdim, ..., (k-1)*outerdim] + indices_dtype = dtype(indices) + range_flatten_index = _range_op(0, outer_dim * in_lastdim, in_lastdim, indices_dtype) + + # expand_dims to (k, 1), then broadcast + ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,)) + x_shape_1d = _get_1d_shape(top_k_input_shape) + + if transposition is not None: + dvalue = tranpose(dvalue, invert_permutation(transposition)) + out_grad = reshape_op( + scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape) + dx = tranpose(out_grad, invert_permutation(transposition)) + else: + dx = reshape_op(scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape) + if not descending: + dx = neg(dx) + return (dx,) + + return bprop + + @bprop_getters.register(P.Identity) def get_bprop_identity(self): """Generate bprop for Identity""" @@ -475,6 +558,7 @@ def get_bprop_range(self): def bprop(x, out, dout): return (zeros_like(x),) + return bprop @@ -506,7 +590,7 @@ def get_bprop_reverse_v2(self): dx = reverse_grad(dout) return (dx,) - return bprop + return bprop @bprop_getters.register(P.Unpack) def get_bprop_unpack(self): diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 33e2af9a6e..9b2eb338e7 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1708,6 +1708,10 @@ test_case_nn_ops = [ 'desc_inputs': [[20, 20, 10]], 'desc_bprop': [[20, 20, 5]], 'skip': ['backward']}), + ('Sort', { + 'block': P.Sort(), + 'desc_inputs': [[2, 3, 4]], + 'desc_bprop': [[2, 3, 4], ([2, 3, 4], {'dtype': np.int32})]}), ('GatherV2_0', { 'block': P.GatherV2(), 'desc_const': [0],