Add grad of sort.

pull/11128/head
liuxiao93 5 years ago
parent 3945bdcabb
commit aacad990c3

@@ -15,8 +15,10 @@
"""array_ops"""
import numpy as np
import mindspore as ms
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
@@ -459,6 +461,87 @@ def get_bprop_sparse_gather_v2(self):
return bprop

@constexpr
def _range_op(start, limit, delta, dtype):
"""helper function for grad of Sort"""
output_tensor = Tensor(list(range(start, limit, delta)), dtype)
return output_tensor
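
# Example: _range_op(0, 12, 4, ms.int32) builds Tensor([0, 4, 8]) with dtype
# int32 -- the flat start offset of each of 3 rows of length 4.
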
@constexpr
def _get_1d_shape(in_shape):
"""helper function for grad of Sort"""
out_shape = 1
for i in in_shape:
out_shape *= i
return (out_shape,)
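
# Example: _get_1d_shape((2, 3, 4)) returns (24,), the shape of the fully
# flattened tensor that ScatterNd scatters the gradient into.
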
@constexpr
def _get_transposition(axis, rank):
"""helper function for grad of Sort"""
if axis < 0:
axis += rank
transposition = np.r_[np.arange(axis), [rank - 1], np.arange(axis + 1, rank - 1), [axis]]
trans = tuple(transposition.tolist())
return trans
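
# Example: _get_transposition(1, 3) returns (0, 2, 1): the sort axis is
# swapped with the last axis so that TopK, which operates on the last axis,
# can reproduce the sort.
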
@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
"""Grad definition for `Sort` operation."""
axis = self.axis
descending = self.descending
scatter = P.ScatterNd()
expand_dims = P.ExpandDims()
reshape_op = P.Reshape()
dtype = P.DType()
topk = P.TopK()
neg = P.Neg()
    transpose = P.Transpose()

def bprop(input_x, out, dout):
x_shape = input_x.shape
k = x_shape[axis]
rank = F.rank(input_x)
dvalue = dout[0]
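        # Ascending order is recast as a descending TopK over the negated
        # input, so both the input and the incoming gradient are negated.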
if not descending:
input_x = neg(input_x)
dvalue = neg(dvalue)
if axis == -1 or (axis + 1) == rank:
transposition = None
top_k_input = input_x
else:
transposition = _get_transposition(axis, rank)
            top_k_input = transpose(input_x, transposition)
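        # With k equal to the full size of the sorted axis, TopK recovers the
        # complete descending order, so `indices` is the permutation the
        # forward Sort applied along that axis.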
_, indices = topk(top_k_input, k)
ind_shape = indices.shape
top_k_input_shape = top_k_input.shape
in_lastdim = top_k_input_shape[-1]
ind_lastdim = ind_shape[-1]
ind_2d = reshape_op(indices, (-1, ind_lastdim))
outer_dim = ind_2d.shape[0]
        # [0, in_lastdim, 2*in_lastdim, ..., (outer_dim-1)*in_lastdim]
indices_dtype = dtype(indices)
range_flatten_index = _range_op(0, outer_dim * in_lastdim, in_lastdim, indices_dtype)
        # expand_dims to (outer_dim, 1), then broadcast-add to offset each row's indices
ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
x_shape_1d = _get_1d_shape(top_k_input_shape)
        if transposition is not None:
            # `invert_permutation` refers to a P.InvertPermutation() instance
            # defined elsewhere in this file.
            dvalue = transpose(dvalue, invert_permutation(transposition))
            out_grad = reshape_op(
                scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
            dx = transpose(out_grad, invert_permutation(transposition))
else:
dx = reshape_op(scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
if not descending:
dx = neg(dx)
return (dx,)
return bprop
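
# For intuition, a minimal NumPy sketch of the scatter-back performed above,
# for a 1-D ascending sort (illustrative only; these names are not from this
# file):
#
#   x = np.array([3., 1., 2.])
#   order = np.argsort(x)               # permutation applied by forward Sort
#   dvalue = np.array([10., 20., 30.])  # gradient w.r.t. the sorted values
#   dx = np.zeros_like(x)
#   dx[order] = dvalue                  # dx == [30., 10., 20.]
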
@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
"""Generate bprop for Identity"""
@@ -475,6 +558,7 @@ def get_bprop_range(self):
def bprop(x, out, dout):
return (zeros_like(x),)
return bprop
@@ -506,7 +590,7 @@ def get_bprop_reverse_v2(self):
dx = reverse_grad(dout)
return (dx,)
    return bprop
@bprop_getters.register(P.Unpack)
def get_bprop_unpack(self):

@@ -1708,6 +1708,10 @@ test_case_nn_ops = [
'desc_inputs': [[20, 20, 10]],
'desc_bprop': [[20, 20, 5]],
'skip': ['backward']}),
('Sort', {
'block': P.Sort(),
'desc_inputs': [[2, 3, 4]],
'desc_bprop': [[2, 3, 4], ([2, 3, 4], {'dtype': np.int32})]}),
('GatherV2_0', {
'block': P.GatherV2(),
'desc_const': [0],
