!3655 gpu support BroadcastTo kernel

Merge pull request !3655 from chenweifeng/broadcast_to
Branch: pull/3655/MERGE
Committed by mindspore-ci-bot via Gitee, 5 years ago
Commit: f1a39a0f72

@@ -51,11 +51,12 @@ class BroadcastToGpuKernel : public GpuKernel {
       MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4";
     }
-    for (int i = input_shapes.size() - 1; i >= 0; i--) {
-      input_shape_[i] = input_shapes[i];
+    size_t offset = output_shapes.size() - input_shapes.size();
+    for (size_t i = 0; i < input_shapes.size(); i++) {
+      input_shape_[i + offset] = input_shapes[i];
     }
-    for (int j = output_shapes.size() - 1; j >= 0; j--) {
+    for (size_t j = 0; j < output_shapes.size(); j++) {
       output_shape_[j] = output_shapes[j];
     }
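
The fix right-aligns the input dimensions against the output dimensions, matching NumPy-style broadcasting, where shapes are compared from the trailing axis. A minimal sketch of the indexing (plain Python; the helper name `pad_input_shape` is hypothetical, and it assumes the kernel's `input_shape_` buffer is pre-filled with 1s, the broadcast-neutral size):

```python
def pad_input_shape(input_shape, output_shape):
    """Right-align input_shape within a rank-len(output_shape) shape."""
    padded = [1] * len(output_shape)
    # Same offset as the C++ fix above.
    offset = len(output_shape) - len(input_shape)
    for i, dim in enumerate(input_shape):
        padded[i + offset] = dim
    return padded

# Broadcasting (4, 5) to (2, 3, 4, 5): the input dims must land in the
# trailing slots. The old left-aligned copy produced [4, 5, 1, 1] instead.
assert pad_input_shape((4, 5), (2, 3, 4, 5)) == [1, 1, 4, 5]
```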

@@ -38,3 +38,9 @@ def test_broadcast():
     output = P.BroadcastTo(shape)(Tensor(x1_np))
     expect = np.broadcast_to(x1_np, shape)
     assert np.allclose(output.asnumpy(), expect)
+
+    x1_np = np.random.rand(4, 5).astype(np.float32)
+    shape = (2, 3, 4, 5)
+    output = P.BroadcastTo(shape)(Tensor(x1_np))
+    expect = np.broadcast_to(x1_np, shape)
+    assert np.allclose(output.asnumpy(), expect)
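
The new test case exercises exactly the rank-expanding path the kernel fix addresses: a rank-2 input broadcast to a rank-4 output. A quick NumPy check (illustrative only, not part of the PR) of why the alignment matters for these shapes:

```python
import numpy as np

x = np.random.rand(4, 5).astype(np.float32)

# Right alignment (the fixed kernel's view): trailing axes line up.
assert np.broadcast_to(x.reshape(1, 1, 4, 5), (2, 3, 4, 5)).shape == (2, 3, 4, 5)

# Left alignment (what the old copy implied) is not broadcastable at all:
# (4, 5, 1, 1) vs (2, 3, 4, 5) mismatches on the leading axes.
try:
    np.broadcast_to(x.reshape(4, 5, 1, 1), (2, 3, 4, 5))
except ValueError:
    pass  # expected: shapes are incompatible
```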
