From 6cead43bdf0140fc64137e9437498f4d74e30cfe Mon Sep 17 00:00:00 2001
From: Peilin Wang
Date: Sat, 20 Mar 2021 01:42:40 -0400
Subject: [PATCH] add general -1 dim behavior for BroadcastTo op

---
 mindspore/ops/operations/array_ops.py     | 58 +++++++++++------------
 tests/st/ops/gpu/test_broadcast_to_ops.py | 12 +++--
 2 files changed, 36 insertions(+), 34 deletions(-)

diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index a746af6d62..2d6c27cfca 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -4551,7 +4551,9 @@ class BroadcastTo(PrimitiveWithInfer):
     the target dimension is -1. In case of -1 in target shape, it will be replaced by the input
     shape's value in that dimension.
 
-    When input shape is broadcast to target shape, it starts with the trailing dimensions.
+    When input shape is broadcast to target shape, it starts with the trailing
+    dimensions. A -1 in the target shape must line up with an existing dimension of
+    the input; it cannot occupy a leading dimension that the input does not have.
 
     Args:
         shape (tuple): The target shape to broadcast. Can be fully specified, or have -1 in one position
@@ -4566,9 +4568,8 @@ class BroadcastTo(PrimitiveWithInfer):
 
     Raises:
         TypeError: If `shape` is not a tuple.
-        ValueError: Given a shape tuple, if it has several -1; or if the -1 is in an invalid position
-            such as one that does not have a opposing dimension in an input tensor; or if the target and
-            input shapes are incompatible.
+        ValueError: If the target and input shapes are incompatible, or if a -1 in the
+            target shape occupies a leading dimension with no counterpart in the input.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -4582,13 +4583,13 @@ class BroadcastTo(PrimitiveWithInfer):
         [[1. 2. 3.]
          [1. 2. 3.]]
 
-        >>> shape = (2, -1)
-        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> shape = (-1, 2)
+        >>> input_x = Tensor(np.array([[1], [2]]).astype(np.float32))
        >>> broadcast_to = ops.BroadcastTo(shape)
        >>> output = broadcast_to(input_x)
        >>> print(output)
-        [[1. 2. 3.]
-         [1. 2. 3.]]
+        [[1. 1.]
+         [2. 2.]]
     """
 
     @prim_attr_register
@@ -4600,35 +4601,30 @@ class BroadcastTo(PrimitiveWithInfer):
             validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
             validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
         self.shape = shape
-        if -1 in self.shape:
-            undef_dims = self.shape.count(-1)
-            if undef_dims > 1:
-                raise ValueError(f'The shape can only has one -1 at most, but has {undef_dims}.')
-            self.dyn = True
-        else:
-            self.dyn = False
 
     def infer_shape(self, x_shape):
         validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
-        target_shape = list(self.shape)
-        outer_dim_offset = len(target_shape) - len(x_shape)
-        if self.dyn:
-            for i, v in enumerate(target_shape):
-                if v == -1:
-                    if i < outer_dim_offset:
-                        raise ValueError(f" -1 in init shape is in an incompatible location"
-                                         f" with given input tensor, -1 index in init shape: {i}"
-                                         f" but -1 can only be in index {len(x_shape)} onwards for this input.")
-                    target_shape[i] = x_shape[i - outer_dim_offset]
+
         reversed_x_shape = tuple(reversed(x_shape))
-        reversed_target = tuple(reversed(target_shape))
+        reversed_filtered_target = []
+        for i, v in enumerate(tuple(reversed(self.shape))):
+            if v == -1:
+                if i >= len(reversed_x_shape):
+                    raise ValueError(f"-1 in the target shape {self.shape} falls in a leading "
+                                     f"dimension that the input shape {tuple(x_shape)} does not have.")
+                reversed_filtered_target.append(reversed_x_shape[i])
+            else:
+                reversed_filtered_target.append(v)
+
+        self.shape = tuple(reversed(reversed_filtered_target))
+        self.add_prim_attr('shape', self.shape)
+
         for i, v in enumerate(reversed_x_shape):
-            if v not in (reversed_target[i], 1):
+            if v not in (reversed_filtered_target[i], 1):
                 raise ValueError(f"Not supported shapes for broadcast, "
-                                 f"x_shape: {tuple(x_shape)}, target shape {target_shape}.")
-        self.shape = tuple(target_shape)
-        self.add_prim_attr('shape', self.shape)
-        return target_shape
+                                 f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
+
+        return self.shape
 
     def infer_dtype(self, x_dtype):
         validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
diff --git a/tests/st/ops/gpu/test_broadcast_to_ops.py b/tests/st/ops/gpu/test_broadcast_to_ops.py
index 86628cd2c4..9e350c35df 100644
--- a/tests/st/ops/gpu/test_broadcast_to_ops.py
+++ b/tests/st/ops/gpu/test_broadcast_to_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@ def test_broadcast_dyn_init(): """ context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - ms_shape = (-1, 4, 5, 6) + ms_shape = (-1, -1, 5, 6) np_shape = (3, 4, 5, 6) x_np = np.random.rand(3, 1, 5, 1).astype(np.float32) output = P.BroadcastTo(ms_shape)(Tensor(x_np)) @@ -66,7 +66,7 @@ def test_broadcast_dyn_init(): expect = np.broadcast_to(x1_np, np_shape) assert np.allclose(output.asnumpy(), expect) - ms_shape = (2, 3, -1, 5) + ms_shape = (2, 3, -1, -1) np_shape = (2, 3, 4, 5) x1_np = np.random.rand(4, 5).astype(np.float32) output = P.BroadcastTo(ms_shape)(Tensor(x1_np)) @@ -87,3 +87,9 @@ def test_broadcast_dyn_invalid_init(): x_np = np.random.rand(4, 5).astype(np.float32) with pytest.raises(ValueError): P.BroadcastTo(ms_shape)(Tensor(x_np)) + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + ms_shape = (-1, 1, -1, -1) + x_np = np.random.rand(4, 5).astype(np.float32) + with pytest.raises(ValueError): + P.BroadcastTo(ms_shape)(Tensor(x_np))
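
Note: a minimal standalone sketch of the resolution rule this patch implements, in
plain Python with NumPy (no MindSpore required). The helper name
resolve_broadcast_shape is hypothetical, for illustration only. Dimensions are
aligned from the trailing end, each -1 picks up the input dimension it lines up
with, and a -1 that falls in a leading dimension the input does not have is
rejected, mirroring the new check in infer_shape:

    import numpy as np

    def resolve_broadcast_shape(x_shape, target_shape):
        """Replace each -1 in target_shape with the input dim it aligns with."""
        if len(x_shape) > len(target_shape):
            raise ValueError("input rank may not exceed target rank")
        resolved = []
        for i, v in enumerate(reversed(target_shape)):
            if v == -1:
                if i >= len(x_shape):
                    # Same condition as the new check in infer_shape.
                    raise ValueError("-1 falls in a leading dimension "
                                     "the input does not have")
                v = x_shape[len(x_shape) - 1 - i]
            resolved.append(v)
        resolved = tuple(reversed(resolved))
        # Trailing-aligned compatibility: each input dim must be 1 or match.
        for xd, td in zip(reversed(x_shape), reversed(resolved)):
            if xd not in (td, 1):
                raise ValueError(f"cannot broadcast {x_shape} to {resolved}")
        return resolved

    # Mirrors the new docstring example: (-1, 2) against input (2, 1) -> (2, 2).
    shape = resolve_broadcast_shape((2, 1), (-1, 2))
    assert shape == (2, 2)
    assert np.broadcast_to([[1.0], [2.0]], shape).tolist() == [[1.0, 1.0], [2.0, 2.0]]

    # Mirrors the new negative test: (-1, 1, -1, -1) against a (4, 5) input fails
    # because the outermost -1 has no corresponding input dimension.
    try:
        resolve_broadcast_shape((4, 5), (-1, 1, -1, -1))
        assert False, "expected ValueError"
    except ValueError:
        pass

This is also why (2, 3, -1, -1) against a (4, 5) input resolves to (2, 3, 4, 5) in
test_broadcast_dyn_init: both -1s sit in trailing positions that line up with
existing input dimensions.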