From b9d7e4e6b491be06784e0656103837a91ef4b4f6 Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Fri, 3 Jul 2020 09:41:15 +0800
Subject: [PATCH] add uniform split in the bprop of concat

---
Review notes (commentary only; text in this area is not part of the commit):

* _concat_grad_uniform(input_shapes, input_nums) returns True when every
  adjacent pair input_shapes[i-1], input_shapes[i] for i in 1..input_nums-1
  compares equal, and False at the first mismatch.
* get_bprop_concat now returns P.Split(axis, input_nums)(dout) as dx when all
  input shapes are equal AND the 'device_target' context value is "Ascend";
  in every other case it keeps the previous per-input P.Slice path.
* out_offset is still built before the branch although only the Slice path
  reads it.
* `constexpr` and `context` are used by the added lines but this patch adds
  no import for them -- presumably both are already imported at the top of
  grad_array_ops.py; confirm before applying.

 mindspore/ops/_grad/grad_array_ops.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index b88d739718..005fdbc895 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -220,19 +220,37 @@ def get_bprop_transpose(self):
     return bprop
 
 
+@constexpr
+def _concat_grad_uniform(input_shapes, input_nums):
+    """Helper function for bprop of Concat"""
+    is_uniform = True
+    for i in range(1, input_nums):
+        if input_shapes[i-1] != input_shapes[i]:
+            is_uniform = False
+            break
+    return is_uniform
+
 @bprop_getters.register(P.Concat)
 def get_bprop_concat(self):
     """Generate bprop for Concat"""
     axis = self.axis
+    is_ascend = context.get_context('device_target') == "Ascend"
 
     def bprop(x, out, dout):
         dx = ()
         out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
-        for i in range(F.tuple_len(x)):
-            slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
-            dx = dx + (slice_out,)
+        input_nums = F.tuple_len(x)
+        input_shapes = ()
+        for i in range(input_nums):
+            input_shapes = input_shapes + (shape_op(x[i]),)
+        is_uniform = _concat_grad_uniform(input_shapes, input_nums)
+        if is_uniform and is_ascend:
+            dx = P.Split(axis, input_nums)(dout)
+        else:
+            for i in range(input_nums):
+                slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
+                dx = dx + (slice_out,)
         return (dx,)
-
     return bprop
 
 