From eb97093f8bcdaeeb3d32afe53f9daf9816db161a Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Thu, 25 Feb 2021 18:36:14 +0800 Subject: [PATCH] add pack bprop --- mindspore/ops/_grad/grad_array_ops.py | 33 +-------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index acb5537b9c..f0787d25b5 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -386,38 +386,6 @@ def _regenerate_output_shape(x_shp, ind_shp, axis): @bprop_getters.register(P.Gather) -def get_bprop_gather(self): - """Generate bprop for GatherV2""" - - def bprop(x, indices, axis, out, dout): - orig_indices = indices - if F.rank(dout) == 0: - dout = P.ExpandDims()(dout, -1) - if F.rank(indices) == 0: - indices = P.ExpandDims()(indices, -1) - x_shp = shape_op(x) - ind_shp = shape_op(indices) - out_shp = _regenerate_output_shape(x_shp, ind_shp, axis) - dout = reshape(dout, out_shp) - - x_shp = shape_op(x) - out_shp = shape_op(dout) - ind_shp = shape_op(indices) - # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) - perm_1 = _generate_shape_index(out_shp, ind_shp, axis) - values_transpose = transpose(dout, perm_1) - if -1 in shape_op(x): - params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) - else: - params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) - # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) - perm_2 = _generate_inverse_index(x_shp, axis) - params_grad = transpose(params_grad, perm_2) - return params_grad, zeros_like(orig_indices), zeros_like(axis) - - return bprop - - @bprop_getters.register(P.GatherV2) def get_bprop_gather_v2(self): """Generate bprop for GatherV2""" @@ -601,6 +569,7 @@ def get_bprop_range(self): return bprop +@bprop_getters.register(P.Pack) @bprop_getters.register(P.Stack) def get_bprop_stack(self): """Generate bprop for Stack"""