!2834 Enable Split in the bprop of Concat

Merge pull request !2834 from gziyan/add_uniform_split_for_concat
pull/2834/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit fc24096baf

@ -220,19 +220,37 @@ def get_bprop_transpose(self):
return bprop
@constexpr
def _concat_grad_uniform(input_shapes, input_nums):
"""Helper function for bprop of Concat"""
is_uniform = True
for i in range(1, input_nums):
if input_shapes[i-1] != input_shapes[i]:
is_uniform = False
break
return is_uniform
@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
"""Generate bprop for Concat"""
axis = self.axis
is_ascend = context.get_context('device_target') == "Ascend"
def bprop(x, out, dout):
dx = ()
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
for i in range(F.tuple_len(x)):
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
input_nums = F.tuple_len(x)
input_shapes = ()
for i in range(input_nums):
input_shapes = input_shapes + (shape_op(x[i]),)
is_uniform = _concat_grad_uniform(input_shapes, input_nums)
if is_uniform and is_ascend:
dx = P.Split(axis, input_nums)(dout)
else:
for i in range(input_nums):
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
dx = dx + (slice_out,)
return (dx,)
return bprop

Loading…
Cancel
Save