!10420 [ME]SparseGatherV2 throw exception when grad with 1D tensor input

From: @chenfei52
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/10420/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit f7eda1118c

@ -433,7 +433,10 @@ def get_bprop_sparse_gather_v2(self):
x_shp = shape_op(x)
if axis == 0:
indices_size = (size_op(indices),)
x_tail_shp = x_shp[1:]
if len(x_shp) <= 1:
x_tail_shp = ()
else:
x_tail_shp = x_shp[1:]
values_shape = indices_size + x_tail_shp
values = reshape(dout, values_shape)
indices_new = reshape(indices, indices_size)

Loading…
Cancel
Save