diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 399972ddf4..aa0346f977 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -4044,11 +4044,21 @@ class BatchToSpaceND(PrimitiveWithInfer):
 
 class BroadcastTo(PrimitiveWithInfer):
     """
     Broadcasts input tensor to a given shape.
-    Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
+
+    Input shape can be broadcast to the target shape if for each dimension pair they are either equal, the input
+    dimension is one, or the target dimension is -1. A -1 in the target shape is replaced by the input shape's
+    value in that dimension.
+    When the input shape is broadcast to the target shape, broadcasting starts with the trailing dimensions.
+    Raises:
+        ValueError: If the shape tuple has more than one -1; if a -1 is in an invalid position, such as one
+            that does not have an opposing dimension in the input tensor; or if the target and input shapes
+            are incompatible.
+
     Args:
-        shape (tuple): The target shape to broadcast.
+        shape (tuple): The target shape to broadcast. Can be fully specified, or contain a -1 in at most one
+            position, where it will be substituted by the input tensor's shape in that dimension; see the examples.
 
     Inputs:
         - **input_x** (Tensor) - The input tensor.
@@ -4067,6 +4077,14 @@
         >>> print(output)
         [[1. 2. 3.]
          [1. 2. 3.]]
+
+        >>> shape = (2, -1)
+        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> broadcast_to = ops.BroadcastTo(shape)
+        >>> output = broadcast_to(input_x)
+        >>> print(output)
+        [[1. 2. 3.]
+         [1. 2. 3.]]
     """
 
     @prim_attr_register
@@ -4074,20 +4092,39 @@
         """Initialize BroadcastTo"""
         validator.check_value_type("shape", shape, (tuple), self.name)
         validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
-        for i in shape:
-            validator.check_positive_int(i, "shape element", self.name)
+        for ix, i in enumerate(shape):
+            validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
+            validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
         self.shape = shape
+        if -1 in self.shape:
+            undef_dims = self.shape.count(-1)
+            if undef_dims > 1:
+                raise ValueError(f'The shape can have at most one -1, but has {undef_dims}.')
+            self.dyn = True
+        else:
+            self.dyn = False
 
     def infer_shape(self, x_shape):
         validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
-
+        target_shape = list(self.shape)
+        outer_dim_offset = len(target_shape) - len(x_shape)
+        if self.dyn:
+            for i, v in enumerate(target_shape):
+                if v == -1:
+                    if i < outer_dim_offset:
+                        raise ValueError(f"-1 in init shape is in an incompatible location"
+                                         f" with the given input tensor: -1 is at index {i} in init shape,"
+                                         f" but it can only appear at index {outer_dim_offset} onwards for this input.")
+                    target_shape[i] = x_shape[i - outer_dim_offset]
         reversed_x_shape = tuple(reversed(x_shape))
-        reversed_target = tuple(reversed(self.shape))
+        reversed_target = tuple(reversed(target_shape))
         for i, v in enumerate(reversed_x_shape):
             if v not in (reversed_target[i], 1):
                 raise ValueError(f"Not supported shapes for broadcast, "
-                                 f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
-        return self.shape
+                                 f"x_shape: {tuple(x_shape)}, target shape {target_shape}.")
+        self.shape = tuple(target_shape)
+        self.add_prim_attr('shape', self.shape)
+        return target_shape
 
     def infer_dtype(self, x_dtype):
         validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
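
For reference, the shape-resolution rule the new `infer_shape` implements can be read in isolation: the target shape is aligned with the input shape from the trailing dimensions, a single -1 is filled in from its opposing input dimension, and the usual broadcast-compatibility check follows. Below is a minimal standalone sketch of that logic, not part of the patch; the helper name `resolve_target_shape` is hypothetical.

```python
def resolve_target_shape(target_shape, x_shape):
    """Substitute a single -1 in target_shape and check broadcast compatibility."""
    target = list(target_shape)
    # Shapes align from the trailing dimensions, so the input only
    # "covers" the last len(x_shape) entries of the target.
    offset = len(target) - len(x_shape)
    for i, dim in enumerate(target):
        if dim == -1:
            if i < offset:
                raise ValueError(f"-1 at index {i} has no opposing input dimension")
            target[i] = x_shape[i - offset]
    # Standard broadcast rule: each input dimension must equal its
    # opposing target dimension or be 1.
    for x_dim, t_dim in zip(reversed(x_shape), reversed(target)):
        if x_dim not in (t_dim, 1):
            raise ValueError(f"cannot broadcast {tuple(x_shape)} to {tuple(target)}")
    return tuple(target)


# (2, -1) with an input of shape (3,) resolves to (2, 3), matching the new
# docstring example; (2, -1, 4, 5) with input (4, 5) raises, matching the new test.
assert resolve_target_shape((2, -1), (3,)) == (2, 3)
```
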
diff --git a/tests/st/ops/gpu/test_broadcast_to_ops.py b/tests/st/ops/gpu/test_broadcast_to_ops.py
index 137f271519..86628cd2c4 100644
--- a/tests/st/ops/gpu/test_broadcast_to_ops.py
+++ b/tests/st/ops/gpu/test_broadcast_to_ops.py
@@ -27,9 +27,8 @@ from mindspore.ops import operations as P
 def test_broadcast():
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
 
-    x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
     shape = (3, 4, 5, 6)
-
+    x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
     output = P.BroadcastTo(shape)(Tensor(x_np))
     expect = np.broadcast_to(x_np, shape)
     assert np.allclose(output.asnumpy(), expect)
@@ -39,8 +38,52 @@
     expect = np.broadcast_to(x1_np, shape)
     assert np.allclose(output.asnumpy(), expect)
 
-    x1_np = np.random.rand(4, 5).astype(np.float32)
     shape = (2, 3, 4, 5)
+    x1_np = np.random.rand(4, 5).astype(np.float32)
     output = P.BroadcastTo(shape)(Tensor(x1_np))
     expect = np.broadcast_to(x1_np, shape)
     assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_dyn_init():
+    """
+    Test running the op with -1 in the init shape to support inputs of varied shapes.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    ms_shape = (-1, 4, 5, 6)
+    np_shape = (3, 4, 5, 6)
+    x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
+    output = P.BroadcastTo(ms_shape)(Tensor(x_np))
+    expect = np.broadcast_to(x_np, np_shape)
+    assert np.allclose(output.asnumpy(), expect)
+
+    x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
+    output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
+    expect = np.broadcast_to(x1_np, np_shape)
+    assert np.allclose(output.asnumpy(), expect)
+
+    ms_shape = (2, 3, -1, 5)
+    np_shape = (2, 3, 4, 5)
+    x1_np = np.random.rand(4, 5).astype(np.float32)
+    output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
+    expect = np.broadcast_to(x1_np, np_shape)
+    assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_dyn_invalid_init():
+    """
+    Test running the op with -1 in an invalid position in the init shape.
+    Expected to raise ValueError.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    ms_shape = (2, -1, 4, 5)
+    x_np = np.random.rand(4, 5).astype(np.float32)
+    with pytest.raises(ValueError):
+        P.BroadcastTo(ms_shape)(Tensor(x_np))
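
To try the new behavior outside the test suite, a hedged usage sketch follows. It assumes a MindSpore build that includes this patch and an available GPU device; adjust `device_target` for other backends.

```python
# Usage sketch mirroring test_broadcast_dyn_init above; assumes a MindSpore
# build with this patch applied and a GPU target configured.
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

# The leading -1 is resolved from the input: with an input of shape
# (3, 1, 5, 1), the target (-1, 4, 5, 6) becomes (3, 4, 5, 6).
x = Tensor(np.random.rand(3, 1, 5, 1).astype(np.float32))
output = P.BroadcastTo((-1, 4, 5, 6))(x)
assert output.asnumpy().shape == (3, 4, 5, 6)
```
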