diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc
index cdb6191697..7239929a82 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc
@@ -29,12 +29,10 @@ using mindspore::schema::PrimitiveType_ExpandDims;
 
 namespace mindspore::kernel {
 int ExpandDimsCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
-  int ret = ReSize();
-  return ret;
+  return ReSize();
 }
 
 int ExpandDimsCPUKernel::ReSize() {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc
index e3e103d6dd..e620b8b2c7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc
@@ -35,18 +35,19 @@ constexpr int kOutputNum = 1;
 }  // namespace
 
 int FillCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
+  return ReSize();
+}
+
+int FillCPUKernel::ReSize() {
   data_size_ = out_tensors_.front()->ElementsNum();
   thread_sz_count_ = MSMIN(thread_count_, data_size_);
   thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
   return RET_OK;
 }
 
-int FillCPUKernel::ReSize() { return RET_OK; }
-
 int FillCPUKernel::DoFill(int task_id) {
   int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_);
   if (size <= 0) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc
index add33c961e..46587db0f1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc
@@ -32,7 +32,10 @@ namespace mindspore::kernel {
 int GatherCPUKernel::Init() {
   axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
   batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_;
-  return RET_OK;
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
 }
 
 int GatherCPUKernel::ReSize() { return RET_OK; }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc
index e76585980e..1f882ea6d6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc
@@ -38,10 +38,17 @@ GatherNdCPUKernel::~GatherNdCPUKernel() {
 }
 
 int GatherNdCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
+  return ReSize();
+}
+
+int GatherNdCPUKernel::ReSize() {
+  if (in_offset_ != nullptr) {
+    free(in_offset_);
+    in_offset_ = nullptr;
+  }
   auto indices_tensor = in_tensors_.at(1);
   auto indices_shape = indices_tensor->shape();
   int indices_rank = indices_shape.size();
@@ -59,16 +66,9 @@ int GatherNdCPUKernel::Init() {
   thread_sz_count_ = MSMIN(thread_count_, count_);
   thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
-  int ret = ReSize();
-  return ret;
-}
 
-int GatherNdCPUKernel::ReSize() {
   auto in_shape = in_tensors_.front()->shape();
   int in_rank = in_shape.size();
-  auto indices_tensor = in_tensors_.at(1);
-  auto indices_shape = indices_tensor->shape();
-  int indices_rank = indices_shape.size();
   int idx_lastshape = indices_shape[indices_rank - 1];
   auto indices_ptr = reinterpret_cast<int *>(indices_tensor->Data());
   area_ = 1;
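The same pattern repeats in every kernel above: Init() keeps only one-time setup and returns early through the new InferShapeDone() guard, while all shape-dependent work moves into ReSize() so the runtime can rerun it once shapes are known or change. Below is a minimal standalone sketch of that contract; the class and members are hypothetical stand-ins, not the MindSpore Lite API.

#include <cstdlib>

// Hypothetical kernel illustrating the Init()/ReSize() split applied above.
class ExampleKernel {
 public:
  ~ExampleKernel() { free(offsets_); }

  int Init() {
    // One-time setup only; shape-dependent work is deferred.
    if (!infer_shape_done_) {
      return 0;  // RET_OK: the runtime calls ReSize() once shapes are known.
    }
    return ReSize();
  }

  int ReSize() {
    // Re-entrant: drop whatever a previous ReSize() allocated
    // (as GatherNdCPUKernel::ReSize() now does with in_offset_).
    if (offsets_ != nullptr) {
      free(offsets_);
      offsets_ = nullptr;
    }
    count_ = 16;  // placeholder for a shape-derived size
    offsets_ = static_cast<int *>(malloc(count_ * sizeof(int)));
    return offsets_ != nullptr ? 0 : -1;
  }

  void set_infer_shape_done(bool done) { infer_shape_done_ = done; }

 private:
  bool infer_shape_done_ = false;
  int count_ = 0;
  int *offsets_ = nullptr;
};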
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc
index ed9253626a..2c3325ad56 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc
@@ -35,40 +35,49 @@ constexpr size_t kOutputNum = 1;
 }  // namespace
 
 int OneHotCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   // indices depth on_value off_value
   if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
     MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size()
                   << ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
     return RET_ERROR;
   }
+  if (context_ == nullptr) {
+    MS_LOG(ERROR) << "OneHot context nullptr";
+    return RET_NULL_PTR;
+  }
+  thread_num_ = context_->thread_num_;
+
+  auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "OneHot op_parameter_ nullptr";
+    return RET_NULL_PTR;
+  }
+  axis_ = param->axis_;
+
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
 
+int OneHotCPUKernel::ReSize() {
   auto indices = in_tensors_.at(0);
   if (indices == nullptr) {
     MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr";
     return RET_NULL_PTR;
   }
   auto indices_shape = indices->shape();
+  const int indices_rank = static_cast<int>(indices_shape.size());
+  if (axis_ < 0) {
+    axis_ += indices_rank + 1;
+  }
+
   outer_size_ = 1;
   for (size_t i = 0; i < static_cast<size_t>(axis_); i++) {
     outer_size_ *= indices_shape[i];
   }
   inner_size_ = indices->ElementsNum() / outer_size_;
-  if (context_ == nullptr) {
-    MS_LOG(ERROR) << "OneHot context nullptr";
-    return RET_NULL_PTR;
-  }
-  thread_num_ = context_->thread_num_;
-
-  const int indices_rank = static_cast<int>(in_tensors_.at(0)->shape().size());
-  if (axis_ < 0) {
-    axis_ += indices_rank + 1;
-  }
-
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h
index 748b473ad9..52bfdcff2f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h
@@ -26,12 +26,12 @@ class OneHotCPUKernel : public LiteKernel {
   OneHotCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                   const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                   const lite::Primitive *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive), context_(ctx) {}
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
 
   ~OneHotCPUKernel() override = default;
 
   int Init() override;
-  int ReSize() override { return 0; };
+  int ReSize() override;
   int Run() override;
   int OneHotImpl(int task_id);
 
@@ -39,7 +39,6 @@ class OneHotCPUKernel : public LiteKernel {
   int GetParams();
 
  private:
-  const lite::Context *context_;
   int thread_num_;
   int axis_;
   int outer_size_;
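For reference, the shape bookkeeping that one_hot.cc now recomputes in ReSize() boils down to normalizing a negative axis against the indices rank and splitting the indices volume around that axis. The free function below is an illustrative sketch of that arithmetic, not the kernel's actual interface: for indices shape {2, 3} and axis = -1, the axis normalizes to 2, giving outer_size = 6 and inner_size = 1.

#include <cstdio>
#include <vector>

// Sketch of the outer/inner size computation moved into
// OneHotCPUKernel::ReSize(); function name and signature are hypothetical.
void ComputeOneHotSizes(const std::vector<int> &indices_shape, int axis,
                        int *outer_size, int *inner_size) {
  const int indices_rank = static_cast<int>(indices_shape.size());
  if (axis < 0) {
    axis += indices_rank + 1;  // the output adds one dimension, hence rank + 1
  }
  int elements = 1;
  for (int dim : indices_shape) {
    elements *= dim;
  }
  *outer_size = 1;
  for (int i = 0; i < axis && i < indices_rank; ++i) {
    *outer_size *= indices_shape[i];
  }
  *inner_size = elements / *outer_size;
}

int main() {
  int outer = 0;
  int inner = 0;
  ComputeOneHotSizes({2, 3}, -1, &outer, &inner);
  printf("outer=%d inner=%d\n", outer, inner);  // prints: outer=6 inner=1
  return 0;
}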
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc
index 3f8d9c43d4..c8c13a802c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc
@@ -36,16 +36,19 @@ constexpr int kOutputNum = 1;
 }  // namespace
 
 int PadCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
     MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << in_tensors_.size()
                   << ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
     return RET_ERROR;
   }
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int PadCPUKernel::ReSize() {
   auto input = in_tensors_.at(0);
   auto output = out_tensors_.at(0);
   if (input == nullptr || output == nullptr) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h
index f2a598d339..f05a9fee54 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h
@@ -35,7 +35,7 @@ class PadCPUKernel : public LiteKernel {
   ~PadCPUKernel() {}
 
   int Init() override;
-  int ReSize() override { return 0; };
+  int ReSize() override;
   int Run() override;
   int RunImpl(int task_id);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
index 64c4a07253..45cbb28496 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
@@ -44,10 +44,7 @@ int ReduceCPUKernel::Init() {
   if (ret != RET_OK) {
     return ret;
   }
-  ret = MallocTmpBuffer();
-  if (ret != RET_OK) {
-    return ret;
-  }
+
   switch (mode_) {
     case static_cast<int>(ReduceMode_ReduceSum): {
       reducer_ = ReduceSum;
@@ -77,12 +74,15 @@ int ReduceCPUKernel::Init() {
       MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_;
       return RET_ERROR;
   }
+
   if (!InferShapeDone()) {
     return RET_OK;
   }
   return ReSize();
 }
 
+int ReduceCPUKernel::ReSize() { return MallocTmpBuffer(); }
+
 int ReduceCPUKernel::CallReduceUnit(int task_id) {
   auto ret = reducer_(outer_size_, inner_size_, axis_size_, src_data_, tmp_shape_.data(), dst_data_, task_id,
                       context_->thread_num_);
@@ -149,6 +149,14 @@ int ReduceCPUKernel::Run() {
 }
 
 int ReduceCPUKernel::MallocTmpBuffer() {
+  for (auto buffer : data_buffers_) {
+    if (buffer != nullptr) {
+      free(buffer);
+      buffer = nullptr;
+    }
+  }
+  data_buffers_.clear();
+
   auto input_shape = in_tensors_.at(0)->shape();
   for (auto i = 0; i < num_axes_ - 1; i++) {
     int axis = axes_[i];
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h
index 5b05b76598..c3f5bd8c8f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h
@@ -48,15 +48,15 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
   }
   int Init() override;
-  int ReSize() override { return 0; };
+  int ReSize() override;
   int Run() override;
   int CallReduceUnit(int task_id);
 
  private:
-  Reducer reducer_;
+  Reducer reducer_ = nullptr;
   std::vector<float *> data_buffers_;
-  const float *src_data_;
-  float *dst_data_;
+  const float *src_data_ = nullptr;
+  float *dst_data_ = nullptr;
 
  private:
   int MallocTmpBuffer();
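The MallocTmpBuffer() change above is what makes ReduceCPUKernel::ReSize() safe to call repeatedly: intermediates allocated by a previous call are released before new ones are created. A self-contained sketch of that idea, with illustrative names rather than the kernel's actual members:

#include <cstdlib>
#include <vector>

// Re-entrant temporary-buffer pool in the spirit of
// ReduceCPUKernel::MallocTmpBuffer(); names are illustrative only.
class TmpBufferPool {
 public:
  ~TmpBufferPool() { Release(); }

  // Safe to call on every ReSize(); each call starts from a clean slate.
  int Reallocate(const std::vector<size_t> &buffer_sizes) {
    Release();
    for (size_t size : buffer_sizes) {
      auto *buf = static_cast<float *>(malloc(size * sizeof(float)));
      if (buf == nullptr) {
        Release();
        return -1;  // allocation failure
      }
      buffers_.push_back(buf);
    }
    return 0;
  }

 private:
  void Release() {
    for (auto *buf : buffers_) {
      free(buf);
    }
    buffers_.clear();
  }

  std::vector<float *> buffers_;
};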
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc
index 1c93f8aabb..f5307743c3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc
@@ -38,6 +38,10 @@ int ReverseCPUKernel::Stride(int index) {
 }
 
 int ReverseCPUKernel::ReSize() {
+  data_size_ = in_tensors_.at(0)->ElementsNum();
+  thread_sz_count_ = MSMIN(thread_count_, data_size_);
+  thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
+
   auto *param = reinterpret_cast<ReverseParameter *>(op_parameter_);
   auto input_shape = in_tensors_[0]->shape();
   if (param->num_axis_ > input_shape.size()) {
@@ -89,13 +93,9 @@ int ReverseCPUKernel::ReSize() {
 }
 
 int ReverseCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
-  data_size_ = in_tensors_.at(0)->ElementsNum();
-  thread_sz_count_ = MSMIN(thread_count_, data_size_);
-  thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
   int ret = ReSize();
   return ret;
 }
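The sizes that fill.cc and reverse.cc now recompute in ReSize() feed a simple per-thread partition: cap the worker count at the element count, then give each worker a ceiling-divided stride. A small sketch of that arithmetic, where UpDiv mirrors the UP_DIV macro and the rest is illustrative: with 10 elements and 4 threads the stride is 3, so the tasks handle 3, 3, 3 and 1 elements.

#include <algorithm>
#include <cstdio>

// Ceiling division, matching the intent of the UP_DIV macro in the kernels.
static int UpDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int thread_count = 4;  // configured thread number
  const int data_size = 10;    // ElementsNum() of the tensor

  // Never launch more workers than there are elements.
  const int thread_sz_count = std::min(thread_count, data_size);
  const int thread_sz_stride = UpDiv(data_size, thread_sz_count);

  // Each task handles at most thread_sz_stride elements; the last one
  // takes the remainder (the same size computed in FillCPUKernel::DoFill()).
  for (int task_id = 0; task_id < thread_sz_count; ++task_id) {
    int size = std::min(thread_sz_stride, data_size - task_id * thread_sz_stride);
    printf("task %d handles %d elements\n", task_id, size);
  }
  return 0;
}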