!3711 fix TopK multi-dimension grad func

Merge pull request !3711 from fangzehua/topkgrad
pull/3711/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 57ce3e5dfc

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Define the grad rules of neural network related operations.""" """Define the grad rules of neural network related operations."""
import math
import numpy as np import numpy as np
from mindspore.ops import _selected_grad_ops as SG from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
@ -628,19 +629,62 @@ def get_bprop_onehot(self):
return bprop return bprop
@constexpr
def _range_op(start, limit, delta, dtype):
    """Helper for the TopK grad: build a 1-D tensor of evenly spaced values.

    Constant-folded at graph-compile time (``@constexpr``). Produces the
    equivalent of ``range(start, limit, delta)`` as a tensor of ``dtype``,
    by feeding an index tensor of the required length into ``inner.Range``.

    Args:
        start: First value of the sequence (converted to float for the op).
        limit: Exclusive upper bound of the sequence.
        delta: Step between consecutive values.
        dtype: MindSpore dtype of the produced tensor (matches the TopK
            indices dtype at the call site).

    Returns:
        A 1-D tensor of length ``ceil((limit - start) / delta)``.
    """
    range_op = inner.Range(float(start), float(limit), float(delta))
    # inner.Range maps each element of its input tensor through the
    # start/limit/delta affine sequence, so the input only needs to have
    # the right length and dtype — its values are 0..length-1 here.
    length_input = math.ceil((limit - start) / delta)
    input_tensor = Tensor(list(range(length_input)), dtype)
    range_out = range_op(input_tensor)
    return range_out
@constexpr
def _get_1d_shape(in_shape):
    """Helper for the TopK grad: collapse a shape tuple to one flat dimension.

    Args:
        in_shape: Shape tuple ``(n_1, ..., n_p)``.

    Returns:
        A 1-tuple ``(n_1 * n_2 * ... * n_p,)`` — the shape of the input
        flattened to a single axis.
    """
    flat_size = 1
    idx = 0
    while idx < len(in_shape):
        flat_size *= in_shape[idx]
        idx += 1
    return (flat_size,)
@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation.

    Scatters the incoming gradient of the top-k values back into a
    zero tensor shaped like the original input. Works for inputs of any
    rank by flattening to 1-D, offsetting each row's indices into the
    flat index space, scattering, and reshaping back.
    """
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
        # Input shape (n_1, ..., n_p); the top-k axis is the last one.
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]

        # out[1] holds the top-k indices, shape (n_1, ..., n_(p-1), k).
        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]

        # Collapse all leading axes: (outerdim, k) with
        # outerdim = n_1 * ... * n_(p-1).
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]

        # Flat-index row offsets [0, in_lastdim, 2*in_lastdim, ...] so each
        # row's last-axis indices land in its own slice of the flat input.
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)

        # Broadcast-add the (outerdim, 1) offsets onto the (outerdim, k)
        # indices, then flatten to a single list of flat positions.
        flat_ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))

        # Scatter the flattened value-gradients into a flat zero tensor of
        # the input's total size, then restore the original shape.
        in_shape_1d = _get_1d_shape(in_shape)
        out_grad = reshape_op(
            scatter(
                expand_dims(flat_ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        # k is an integer hyperparameter; its gradient is zero.
        return out_grad, zeros_like(k)

    return bprop

Loading…
Cancel
Save