|
|
|
@ -351,15 +351,30 @@ def _generate_inverse_index(x_shape, axis):
|
|
|
|
|
return perm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
def _regenerate_output_shape(x_shp, ind_shp, axis):
    """
    Rebuild the gather output shape from the input shape and the indices shape.

    The entry of `x_shp` at position `axis` is dropped and the entries of
    `ind_shp` are spliced in at that position.

    Args:
        x_shp: Shape of the gathered input; must support `len`, slicing and `+`
            (presumably a tuple, as produced by the shape op).
        ind_shp: Shape of the indices; must be concatenable with slices of `x_shp`.
        axis (int): Gather axis. A negative value is offset by `len(x_shp)` once.

    Returns:
        The concatenation `x_shp[:axis] + ind_shp + x_shp[axis + 1:]`.
    """
    ndim = len(x_shp)
    # Negative axes count from the end; shift them once by the rank.
    pivot = axis + ndim if axis < 0 else axis
    leading = x_shp[:pivot]
    trailing = x_shp[pivot + 1:]
    return leading + ind_shp + trailing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.GatherV2)
|
|
|
|
|
def get_bprop_gather_v2(self):
|
|
|
|
|
"""Generate bprop for GatherV2"""
|
|
|
|
|
|
|
|
|
|
def bprop(x, indices, axis, out, dout):
|
|
|
|
|
orig_indices = indices
|
|
|
|
|
if F.rank(dout) == 0:
|
|
|
|
|
dout = P.ExpandDims()(dout, -1)
|
|
|
|
|
if F.rank(indices) == 0:
|
|
|
|
|
indices = P.ExpandDims()(indices, -1)
|
|
|
|
|
x_shp = shape_op(x)
|
|
|
|
|
ind_shp = shape_op(indices)
|
|
|
|
|
out_shp = _regenerate_output_shape(x_shp, ind_shp, axis)
|
|
|
|
|
dout = reshape(dout, out_shp)
|
|
|
|
|
|
|
|
|
|
x_shp = shape_op(x)
|
|
|
|
|
out_shp = shape_op(dout)
|
|
|
|
|
ind_shp = shape_op(indices)
|
|
|
|
@ -373,7 +388,7 @@ def get_bprop_gather_v2(self):
|
|
|
|
|
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
|
|
|
|
|
perm_2 = _generate_inverse_index(x_shp, axis)
|
|
|
|
|
params_grad = transpose(params_grad, perm_2)
|
|
|
|
|
return params_grad, zeros_like(indices), zeros_like(axis)
|
|
|
|
|
return params_grad, zeros_like(orig_indices), zeros_like(axis)
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|