From afcb3e9b45acea82bd206aae861d1eea82aac008 Mon Sep 17 00:00:00 2001 From: zhanyuan Date: Mon, 7 Sep 2020 10:41:16 +0800 Subject: [PATCH] Support exponent tensor broadcast for power op --- mindspore/lite/src/ops/power.cc | 4 +++- mindspore/lite/src/runtime/kernel/arm/fp32/power.cc | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc index d812ed1d1f..aca5e1c8c6 100644 --- a/mindspore/lite/src/ops/power.cc +++ b/mindspore/lite/src/ops/power.cc @@ -64,7 +64,9 @@ int Power::InferShape(std::vector inputs, std::vectorshape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) { + if ((exp_tensor->shape().size() > 1 && exp_tensor->shape() != x_tensor->shape()) || + (exp_tensor->shape().size() == 1 && exp_tensor->shape()[0] != 1) || + exp_tensor->data_type() != x_tensor->data_type()) { MS_LOG(ERROR) << "Power inputs shape or type is not equal!"; return RET_INPUT_TENSOR_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc index 4b1cef6fdb..089cb8ba59 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc @@ -64,11 +64,11 @@ int PowerCPUKernel::RunImpl(int task_id) { bool broadcast = true; if (in_tensors_.size() == 2) { exp_addr = reinterpret_cast(in_tensors_[1]->Data()); - broadcast = false; + broadcast = in_tensors_[0]->shape() == in_tensors_[1]->shape() ? false : true; } float *cur_exp = nullptr; if (broadcast) { - cur_exp = &power_; + cur_exp = in_tensors_.size() == 2 ? exp_addr : &power_; } else { cur_exp = exp_addr + stride * task_id; }