|
|
|
@ -51,11 +51,12 @@ class BroadcastToGpuKernel : public GpuKernel {
|
|
|
|
|
MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = input_shapes.size() - 1; i >= 0; i--) {
|
|
|
|
|
input_shape_[i] = input_shapes[i];
|
|
|
|
|
size_t offset = output_shapes.size() - input_shapes.size();
|
|
|
|
|
for (size_t i = 0; i < input_shapes.size(); i++) {
|
|
|
|
|
input_shape_[i + offset] = input_shapes[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int j = output_shapes.size() - 1; j >= 0; j--) {
|
|
|
|
|
for (size_t j = 0; j < output_shapes.size(); j++) {
|
|
|
|
|
output_shape_[j] = output_shapes[j];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|