diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/elu.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/elu.cc index ff133f28fd..638e4cc915 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/elu.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/elu.cc @@ -28,12 +28,18 @@ namespace mindspore::kernel { int EluCPUKernel::Init() { elu_parameter_ = reinterpret_cast(opParameter); elu_parameter_->thread_num_ = thread_count_; + + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int EluCPUKernel::ReSize() { elu_parameter_->in_size_ = inputs_.front()->ElementsNum(); return RET_OK; } -int EluCPUKernel::ReSize() { return RET_OK; } - int EluCPUKernel::DoExcute(int task_id) { Elu(input_addr, output_addr, elu_parameter_, task_id); } int EluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { @@ -47,6 +53,11 @@ int EluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int EluCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } input_addr = reinterpret_cast(inputs_.front()->Data()); output_addr = reinterpret_cast(outputs_.front()->Data()); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc index fe60dda8a2..d41590d06d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc @@ -26,12 +26,16 @@ using mindspore::schema::PrimitiveType_EmbeddingLookup; namespace mindspore::kernel { int EmbeddingLookupCPUKernel::Init() { - if (context_->infer_shape_interrupt_ && !context_->running_) { - SetNeedReInit(); - return RET_OK; - } embedding_lookup_parameter_ = reinterpret_cast(opParameter); embedding_lookup_parameter_->thread_num = thread_count_; + + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int EmbeddingLookupCPUKernel::ReSize() { embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum(); embedding_lookup_parameter_->layer_size_ = 1; @@ -45,18 +49,34 @@ int EmbeddingLookupCPUKernel::Init() { embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0]; } - input_addr_ = reinterpret_cast( - std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); + if (input_addr_ != nullptr) { + free(input_addr_); + } + if (context_ != nullptr && context_->allocator != nullptr) { + input_addr_ = reinterpret_cast(context_->allocator->Malloc( + sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); + } else { + input_addr_ = reinterpret_cast( + malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); + } if (input_addr_ == nullptr) { - MS_LOG(ERROR) << "Create memory failed"; - return mindspore::lite::RET_MEMORY_FAILED; + MS_LOG(ERROR) << "Malloc buffer failed"; + return RET_ERROR; } - embedding_lookup_parameter_->is_regulated_ = - reinterpret_cast(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); + if (embedding_lookup_parameter_->is_regulated_ != nullptr) { + free(embedding_lookup_parameter_->is_regulated_); + } + if (context_ != nullptr && context_->allocator != nullptr) { + embedding_lookup_parameter_->is_regulated_ = + reinterpret_cast(context_->allocator->Malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); + } else { + embedding_lookup_parameter_->is_regulated_ = + reinterpret_cast(malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); + } if (embedding_lookup_parameter_->is_regulated_ == nullptr) { - MS_LOG(ERROR) << "Create memory failed"; - return mindspore::lite::RET_MEMORY_FAILED; + MS_LOG(ERROR) << "Malloc buffer failed"; + return RET_ERROR; } for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) { @@ -66,8 +86,6 @@ int EmbeddingLookupCPUKernel::Init() { return RET_OK; } -int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; } - int EmbeddingLookupCPUKernel::DoExcute(int task_id) { int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id); if (error_code != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h index fd9defd03b..5a80550e82 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h @@ -28,7 +28,14 @@ class EmbeddingLookupCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const lite::Primitive *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {} - ~EmbeddingLookupCPUKernel() override{}; + ~EmbeddingLookupCPUKernel() override { + if (input_addr_ != nullptr) { + free(input_addr_); + } + if (embedding_lookup_parameter_->is_regulated_ != nullptr) { + free(embedding_lookup_parameter_->is_regulated_); + } + }; int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.cc index 37798eaaed..5ae5b31d66 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.cc @@ -15,7 +15,6 @@ */ #include "src/runtime/kernel/arm/nnacl/fp32/elu.h" -#include #include "include/errorcode.h" #include "src/runtime/kernel/arm/nnacl/errorcode.h" #include "mindspore/core/utils/log_adapter.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.h index 2b4c10e7b8..6a510c9d47 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.h @@ -19,12 +19,12 @@ #include "src/runtime/kernel/arm/nnacl/op_base.h" -struct EluParameter { +typedef struct { OpParameter op_parameter_; float alpha_; int thread_num_; int in_size_; -}; +} EluParameter; int Elu(float *input_data, float *output_data, EluParameter *parameter, int task_id); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc index 964041fa3c..93a5438cb9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc @@ -15,7 +15,6 @@ */ #include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" -#include #include "include/errorcode.h" #include "src/runtime/kernel/arm/nnacl/errorcode.h" #include "mindspore/core/utils/log_adapter.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h index fa9f0ce5da..1a8087ad56 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h @@ -19,15 +19,15 @@ #include "src/runtime/kernel/arm/nnacl/op_base.h" -struct EmbeddingLookupParameter { - OpParameter op_parameter_; - bool *is_regulated_; - float max_norm_; - int ids_size_; - int layer_size_; - int layer_num_; - int thread_num; -}; +typedef struct { + OpParameter op_parameter_; + bool *is_regulated_; + float max_norm_; + int ids_size_; + int layer_size_; + int layer_num_; + int thread_num; +} EmbeddingLookupParameter; int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id);