solve broadcast two same shape bprop error

make unsupported shape error info explicit
pull/2356/head
zhaozhenlong 5 years ago
parent 9bc2ffde54
commit 5962c6efe9

@ -673,6 +673,10 @@ def get_bprop_broadcast_to(self):
def bprop(x, out, dout):
x_shape = shape_op(x)
dout_shape = shape_op(dout)
if x_shape == dout_shape:
return (dout,)
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
reduced_grad = reduce_keep_dim(dout, reduction_axes)
dx = reshape(reduced_grad, x_shape)

@ -2719,6 +2719,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
class BroadcastTo(PrimitiveWithInfer):
"""
Broadcasts input tensor to a given shape.
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
When input shape is broadcast to target shape, it starts with the trailing dimensions.
Args:
shape (tuple): The target shape to broadcast.
@ -2741,11 +2743,20 @@ class BroadcastTo(PrimitiveWithInfer):
def __init__(self, shape):
"""Init BroadcastTo"""
validator.check_value_type("shape", shape, (tuple), self.name)
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
for i in shape:
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
self.shape = shape
def infer_shape(self, x_shape):
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
reversed_x_shape = tuple(reversed(x_shape))
reversed_target = tuple(reversed(self.shape))
for i, v in enumerate(reversed_x_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"Not supported shapes for broadcast, "
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
return self.shape
def infer_dtype(self, x_dtype):

Loading…
Cancel
Save