|
|
|
@ -4044,11 +4044,21 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|
|
|
|
class BroadcastTo(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
Broadcasts input tensor to a given shape.
|
|
|
|
|
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
|
|
|
|
|
|
|
|
|
|
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one or
|
|
|
|
|
the target dimension is -1. In case of -1 in target shape, it will be replaced by the input shape's value
|
|
|
|
|
in that dimension.
|
|
|
|
|
|
|
|
|
|
When input shape is broadcast to target shape, it starts with the trailing dimensions.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: Given a shape tuple, if it has several -1s; or if the -1 is in an invalid position
|
|
|
|
|
such as one that does not have a opposing dimension in an input tensor; of if the target and
|
|
|
|
|
input shapes are incompatiable.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
shape (tuple): The target shape to broadcast.
|
|
|
|
|
shape (tuple): The target shape to broadcast. Can be fully specified, or have '-1's in one position
|
|
|
|
|
where it will be substituted by the input tensor's shape in that position, see example.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - The input tensor.
|
|
|
|
@ -4067,6 +4077,14 @@ class BroadcastTo(PrimitiveWithInfer):
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[[1. 2. 3.]
|
|
|
|
|
[1. 2. 3.]]
|
|
|
|
|
|
|
|
|
|
>>> shape = (2, -1)
|
|
|
|
|
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
|
|
|
|
|
>>> broadcast_to = ops.BroadcastTo(shape)
|
|
|
|
|
>>> output = broadcast_to(input_x)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[[1. 2. 3.]
|
|
|
|
|
[1. 2. 3.]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@ -4074,20 +4092,39 @@ class BroadcastTo(PrimitiveWithInfer):
|
|
|
|
|
"""Initialize BroadcastTo"""
|
|
|
|
|
validator.check_value_type("shape", shape, (tuple), self.name)
|
|
|
|
|
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
|
|
|
|
|
for i in shape:
|
|
|
|
|
validator.check_positive_int(i, "shape element", self.name)
|
|
|
|
|
for ix, i in enumerate(shape):
|
|
|
|
|
validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
|
|
|
|
|
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
|
|
|
|
|
self.shape = shape
|
|
|
|
|
if -1 in self.shape:
|
|
|
|
|
undef_dims = self.shape.count(-1)
|
|
|
|
|
if undef_dims > 1:
|
|
|
|
|
raise ValueError(f'The shape can only has one -1 at most, but has {undef_dims}.')
|
|
|
|
|
self.dyn = True
|
|
|
|
|
else:
|
|
|
|
|
self.dyn = False
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
|
|
|
|
|
|
|
|
|
|
target_shape = list(self.shape)
|
|
|
|
|
outer_dim_offset = len(target_shape) - len(x_shape)
|
|
|
|
|
if self.dyn:
|
|
|
|
|
for i, v in enumerate(target_shape):
|
|
|
|
|
if v == -1:
|
|
|
|
|
if i < outer_dim_offset:
|
|
|
|
|
raise ValueError(f" -1 in init shape is in an incompatible location"
|
|
|
|
|
f" with given input tensor, -1 index in init shape: {i}"
|
|
|
|
|
f" but -1 can only be in index {len(x_shape)} onwards for this input.")
|
|
|
|
|
target_shape[i] = x_shape[i - outer_dim_offset]
|
|
|
|
|
reversed_x_shape = tuple(reversed(x_shape))
|
|
|
|
|
reversed_target = tuple(reversed(self.shape))
|
|
|
|
|
reversed_target = tuple(reversed(target_shape))
|
|
|
|
|
for i, v in enumerate(reversed_x_shape):
|
|
|
|
|
if v not in (reversed_target[i], 1):
|
|
|
|
|
raise ValueError(f"Not supported shapes for broadcast, "
|
|
|
|
|
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
|
|
|
|
|
return self.shape
|
|
|
|
|
f"x_shape: {tuple(x_shape)}, target shape {target_shape}.")
|
|
|
|
|
self.shape = tuple(target_shape)
|
|
|
|
|
self.add_prim_attr('shape', self.shape)
|
|
|
|
|
return target_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
|
|
|
|
|