|
|
|
@ -220,19 +220,37 @@ def get_bprop_transpose(self):
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _concat_grad_uniform(input_shapes, input_nums):
|
|
|
|
|
"""Helper function for bprop of Concat"""
|
|
|
|
|
is_uniform = True
|
|
|
|
|
for i in range(1, input_nums):
|
|
|
|
|
if input_shapes[i-1] != input_shapes[i]:
|
|
|
|
|
is_uniform = False
|
|
|
|
|
break
|
|
|
|
|
return is_uniform
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.Concat)
|
|
|
|
|
def get_bprop_concat(self):
|
|
|
|
|
"""Generate bprop for Concat"""
|
|
|
|
|
axis = self.axis
|
|
|
|
|
is_ascend = context.get_context('device_target') == "Ascend"
|
|
|
|
|
|
|
|
|
|
def bprop(x, out, dout):
|
|
|
|
|
dx = ()
|
|
|
|
|
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
|
|
|
|
|
for i in range(F.tuple_len(x)):
|
|
|
|
|
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
|
|
|
|
|
dx = dx + (slice_out,)
|
|
|
|
|
input_nums = F.tuple_len(x)
|
|
|
|
|
input_shapes = ()
|
|
|
|
|
for i in range(input_nums):
|
|
|
|
|
input_shapes = input_shapes + (shape_op(x[i]),)
|
|
|
|
|
is_uniform = _concat_grad_uniform(input_shapes, input_nums)
|
|
|
|
|
if is_uniform and is_ascend:
|
|
|
|
|
dx = P.Split(axis, input_nums)(dout)
|
|
|
|
|
else:
|
|
|
|
|
for i in range(input_nums):
|
|
|
|
|
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
|
|
|
|
|
dx = dx + (slice_out,)
|
|
|
|
|
return (dx,)
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|