diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index 7d07ddc45b..92517f03a9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -22,6 +22,7 @@ #include "src/runtime/kernel/arm/int8/add_int8.h" #include "src/runtime/kernel/arm/int8/mul_int8.h" #include "src/runtime/runtime_api.h" +#include "src/populate_parameter.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -40,6 +41,31 @@ int ArithmeticCPUKernel::Init() { return ReSize(); } +int ArithmeticCPUKernel::PreProcess() { + if (!InferShapeDone()) { + (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true); + auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_); + if (ret != 0) { + (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false); + MS_LOG(ERROR) << "InferShape fail!"; + return ret; + } + arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(kernel::PopulateArithmetic(primitive_)); + ret = ReSize(); + if (ret != 0) { + MS_LOG(ERROR) << "ReSize fail!ret: " << ret; + return ret; + } + } + + auto outputs = this->out_tensors(); + for (auto *output : outputs) { + MS_ASSERT(output != nullptr); + output->MallocData(); + } + return RET_OK; +} + int ArithmeticCPUKernel::ReSize() { if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) { data_type_ = kDataTypeFloat; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h index d11d13d2b1..51ebf5ab62 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h @@ -163,6 +163,7 @@ class ArithmeticCPUKernel : public LiteKernel { ~ArithmeticCPUKernel() override; int Init() override; + int PreProcess() override; int ReSize() override; int Run() override; int DoArithmetic(int task_id); diff --git 
a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc index 21c9f775d7..92cd9402e9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc @@ -81,7 +81,7 @@ int QuantizedAddCPUKernel::Run() { input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData()); output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData()); - elements_num_ = in_tensors_.at(0)->ElementsNum(); + elements_num_ = out_tensors_.at(0)->ElementsNum(); count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc index 80bccc5c49..69066c0938 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc @@ -106,7 +106,7 @@ int MulInt8CPUKernel::Run() { input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData()); output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData()); - elements_num_ = in_tensors_.at(0)->ElementsNum(); + elements_num_ = out_tensors_.at(0)->ElementsNum(); count_unit_ = thread_count_ > 1 ? 
UP_DIV(elements_num_, thread_count_) : elements_num_; if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) { input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size())); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc index cb638f718a..4bb9867ff1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc @@ -87,8 +87,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; return RET_ERROR; } - attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0)); - attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1)); + attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); + attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "kernels") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; @@ -101,8 +101,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; return RET_ERROR; } - attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0)); - attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1)); + attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); + attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "auto_pad") { attr->padMode = GetOnnxPadMode(onnx_node_attr); } else if (onnx_node_attr.name() == "pads") { @@ -119,8 +119,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; return RET_ERROR; } - attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0)); - attr->strideH = 
static_cast<int32_t>(onnx_node_attr.ints(1)); + attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); + attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "order") { if (onnx_node_attr.s() == "NHWC") { attr->format = schema::Format::Format_NHWC;