@@ -15,8 +15,10 @@
"""array_ops"""
import numpy as np
import mindspore as ms
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner

@@ -459,6 +461,87 @@ def get_bprop_sparse_gather_v2(self):
    return bprop


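# The helpers below are constexpr: they run at graph-compile time on Python values and
# hand the Sort gradient constant tensors, shapes and permutations.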
@constexpr
def _range_op(start, limit, delta, dtype):
    """helper function for grad of Sort: build a constant 1-D tensor holding range(start, limit, delta)"""
    output_tensor = Tensor(list(range(start, limit, delta)), dtype)
    return output_tensor


@constexpr
def _get_1d_shape(in_shape):
    """helper function for grad of Sort: collapse `in_shape` into a one-element shape tuple"""
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)


@constexpr
def _get_transposition(axis, rank):
    """helper function for grad of Sort: permutation that swaps `axis` with the last axis"""
    if axis < 0:
        axis += rank
    transposition = np.r_[np.arange(axis), [rank - 1], np.arange(axis + 1, rank - 1), [axis]]
    trans = tuple(transposition.tolist())
    return trans


@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
    """Grad definition for `Sort` operation."""
    axis = self.axis
    descending = self.descending
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    reshape_op = P.Reshape()
    dtype = P.DType()
    topk = P.TopK()
    neg = P.Neg()
    tranpose = P.Transpose()
    invert_permutation = P.InvertPermutation()

    def bprop(input_x, out, dout):
        x_shape = input_x.shape
        k = x_shape[axis]
        rank = F.rank(input_x)
        dvalue = dout[0]
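        # An ascending sort is realised by negating the input and reusing the
        # descending TopK path below, so the incoming gradient is negated to match.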
        if not descending:
            input_x = neg(input_x)
            dvalue = neg(dvalue)
        if axis == -1 or (axis + 1) == rank:
            transposition = None
            top_k_input = input_x
        else:
            transposition = _get_transposition(axis, rank)
            top_k_input = tranpose(input_x, transposition)

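        # TopK over the full last dimension reproduces the sort order of `top_k_input`.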
        _, indices = topk(top_k_input, k)
        ind_shape = indices.shape
        top_k_input_shape = top_k_input.shape
        in_lastdim = top_k_input_shape[-1]
        ind_lastdim = ind_shape[-1]
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outer_dim = ind_2d.shape[0]

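        # Convert the per-row sort indices into indices over the flattened input so a
        # single ScatterNd can route the gradient back to the pre-sort positions.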
        # [0, in_lastdim, 2 * in_lastdim, ..., (outer_dim - 1) * in_lastdim]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outer_dim * in_lastdim, in_lastdim, indices_dtype)

        # expand_dims to (outer_dim, 1), then broadcast against ind_2d
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        x_shape_1d = _get_1d_shape(top_k_input_shape)

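        # Scatter the sorted-value gradient back to the original element positions;
        # if the sort axis was moved to the back, undo that transposition afterwards.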
        if transposition is not None:
            dvalue = tranpose(dvalue, invert_permutation(transposition))
            out_grad = reshape_op(
                scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
            dx = tranpose(out_grad, invert_permutation(transposition))
        else:
            dx = reshape_op(scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
        if not descending:
            dx = neg(dx)
        return (dx,)

    return bprop


@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
    """Generate bprop for Identity"""

@@ -475,6 +558,7 @@ def get_bprop_range(self):
    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop

@@ -506,7 +590,7 @@ def get_bprop_reverse_v2(self):
        dx = reverse_grad(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Unpack)
def get_bprop_unpack(self):