update BroadcastTo

typo in API

added init type check + fix doc

fixes

lint

added prim_attr setter in python API

revert back one

changed self.shape / add prim attr

fix
pull/9540/head
danishnxt 4 years ago
parent b793c7b291
commit 67c528386e

@@ -4044,11 +4044,21 @@ class BatchToSpaceND(PrimitiveWithInfer):
 class BroadcastTo(PrimitiveWithInfer):
     """
     Broadcasts input tensor to a given shape.
-    Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
+    Input shape can be broadcast to the target shape if for each dimension pair they are either equal, the input
+    is 1, or the target dimension is -1. A -1 in the target shape is replaced by the input shape's value
+    in that dimension.

     When input shape is broadcast to target shape, it starts with the trailing dimensions.

+    Raises:
+        ValueError: Given a shape tuple, if it contains more than one -1; or if a -1 is in an invalid position,
+            such as one that does not have an opposing dimension in the input tensor; or if the target and
+            input shapes are incompatible.
+
     Args:
-        shape (tuple): The target shape to broadcast.
+        shape (tuple): The target shape to broadcast. Can be fully specified, or contain -1 in at most one
+            position, where it will be substituted by the input tensor's shape in that dimension; see the example.

     Inputs:
         - **input_x** (Tensor) - The input tensor.
@@ -4067,6 +4077,14 @@ class BroadcastTo(PrimitiveWithInfer):
         >>> print(output)
         [[1. 2. 3.]
          [1. 2. 3.]]
+        >>> shape = (2, -1)
+        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> broadcast_to = ops.BroadcastTo(shape)
+        >>> output = broadcast_to(input_x)
+        >>> print(output)
+        [[1. 2. 3.]
+         [1. 2. 3.]]
     """

     @prim_attr_register
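For readers of the diff, a minimal standalone sketch (not part of the commit) of the -1 resolution rule described in the docstring above; resolve_shape is a hypothetical helper name, and plain Python stands in for the primitive's infer machinery:

    def resolve_shape(target, x_shape):
        """Resolve -1 entries in `target` against `x_shape`, right-aligned."""
        offset = len(target) - len(x_shape)
        resolved = list(target)
        for i, dim in enumerate(resolved):
            if dim == -1:
                if i < offset:
                    # No opposing input dimension exists this far left.
                    raise ValueError(f"-1 at index {i} has no matching input dimension")
                resolved[i] = x_shape[i - offset]
        return tuple(resolved)

    assert resolve_shape((2, -1), (3,)) == (2, 3)
    assert resolve_shape((-1, 4, 5, 6), (3, 1, 5, 1)) == (3, 4, 5, 6)

Right-aligning the two shapes mirrors NumPy broadcasting: a -1 left of the input's rank has no opposing input dimension, which is the invalid-position case the Raises section documents.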
@@ -4074,20 +4092,39 @@ class BroadcastTo(PrimitiveWithInfer):
         """Initialize BroadcastTo"""
         validator.check_value_type("shape", shape, (tuple), self.name)
         validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
-        for i in shape:
-            validator.check_positive_int(i, "shape element", self.name)
+        for ix, i in enumerate(shape):
+            validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
+            validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
         self.shape = shape
+        if -1 in self.shape:
+            undef_dims = self.shape.count(-1)
+            if undef_dims > 1:
+                raise ValueError(f'The shape can contain at most one -1, but it has {undef_dims}.')
+            self.dyn = True
+        else:
+            self.dyn = False

     def infer_shape(self, x_shape):
         validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
+        target_shape = list(self.shape)
+        outer_dim_offset = len(target_shape) - len(x_shape)
+        if self.dyn:
+            for i, v in enumerate(target_shape):
+                if v == -1:
+                    if i < outer_dim_offset:
+                        raise ValueError(f"-1 in the init shape is in an incompatible location"
+                                         f" for the given input tensor; -1 is at index {i} of the init shape,"
+                                         f" but can only appear at index {outer_dim_offset} or later for this input.")
+                    target_shape[i] = x_shape[i - outer_dim_offset]
         reversed_x_shape = tuple(reversed(x_shape))
-        reversed_target = tuple(reversed(self.shape))
+        reversed_target = tuple(reversed(target_shape))
         for i, v in enumerate(reversed_x_shape):
             if v not in (reversed_target[i], 1):
                 raise ValueError(f"Not supported shapes for broadcast, "
-                                 f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
-        return self.shape
+                                 f"x_shape: {tuple(x_shape)}, target shape {target_shape}.")
+        self.shape = tuple(target_shape)
+        self.add_prim_attr('shape', self.shape)
+        return target_shape

     def infer_dtype(self, x_dtype):
         validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
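A quick sanity check of the trailing-dimension rule enforced in infer_shape above, as a pure-Python sketch (check_broadcastable is a hypothetical helper, not MindSpore API): dimensions are compared right to left, and each input dimension must equal the target dimension or be 1.

    def check_broadcastable(x_shape, target_shape):
        # Compare right-aligned dimension pairs, as infer_shape does above.
        for xd, td in zip(reversed(x_shape), reversed(target_shape)):
            if xd not in (td, 1):
                raise ValueError(f"cannot broadcast {x_shape} to {target_shape}")

    check_broadcastable((3, 1, 5, 1), (3, 4, 5, 6))   # passes: 1s expand, equals match
    # check_broadcastable((2, 5), (3, 4, 5, 6))       # would raise: trailing 5 vs 6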

@@ -27,9 +27,8 @@ from mindspore.ops import operations as P
 def test_broadcast():
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

-    x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
     shape = (3, 4, 5, 6)
+    x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
     output = P.BroadcastTo(shape)(Tensor(x_np))
     expect = np.broadcast_to(x_np, shape)
     assert np.allclose(output.asnumpy(), expect)
@@ -39,8 +38,52 @@ def test_broadcast():
     expect = np.broadcast_to(x1_np, shape)
     assert np.allclose(output.asnumpy(), expect)

-    x1_np = np.random.rand(4, 5).astype(np.float32)
     shape = (2, 3, 4, 5)
+    x1_np = np.random.rand(4, 5).astype(np.float32)
     output = P.BroadcastTo(shape)(Tensor(x1_np))
     expect = np.broadcast_to(x1_np, shape)
     assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_dyn_init():
+    """
+    Test running the op with -1 in the init shape to support varied inputs.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    ms_shape = (-1, 4, 5, 6)
+    np_shape = (3, 4, 5, 6)
+    x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
+    output = P.BroadcastTo(ms_shape)(Tensor(x_np))
+    expect = np.broadcast_to(x_np, np_shape)
+    assert np.allclose(output.asnumpy(), expect)
+
+    x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
+    output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
+    expect = np.broadcast_to(x1_np, np_shape)
+    assert np.allclose(output.asnumpy(), expect)
+
+    ms_shape = (2, 3, -1, 5)
+    np_shape = (2, 3, 4, 5)
+    x1_np = np.random.rand(4, 5).astype(np.float32)
+    output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
+    expect = np.broadcast_to(x1_np, np_shape)
+    assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_dyn_invalid_init():
+    """
+    Test running the op with -1 in an invalid position in the init shape.
+    Expected to fail.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    ms_shape = (2, -1, 4, 5)
+    x_np = np.random.rand(4, 5).astype(np.float32)
+    with pytest.raises(ValueError):
+        P.BroadcastTo(ms_shape)(Tensor(x_np))
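One note on the tests above: np.broadcast_to does not accept -1 in the target shape, which is why each case carries an ms_shape/np_shape pair with the -1 already resolved. A minimal sketch of that resolution (assuming equal ranks, as holds in the tests):

    import numpy as np

    x = np.random.rand(3, 1, 5, 1).astype(np.float32)
    ms_shape = (-1, 4, 5, 6)
    # Resolve -1 from the input's shape; this simple version assumes ms_shape
    # and x have the same rank, which is true for the cases above.
    np_shape = tuple(x.shape[i] if d == -1 else d for i, d in enumerate(ms_shape))
    assert np_shape == (3, 4, 5, 6)
    expect = np.broadcast_to(x, np_shape)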
