diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index e30c91544d..36c5a767a6 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -28,30 +28,8 @@ using mindspore::kernel::kCPU; using mindspore::kernel::KERNEL_ARCH; using mindspore::kernel::KernelCreator; using mindspore::kernel::KernelKey; -using mindspore::kernel::kKernelArch_MAX; -using mindspore::kernel::kKernelArch_MIN; -using mindspore::schema::PrimitiveType_MAX; -using mindspore::schema::PrimitiveType_MIN; namespace mindspore::lite { -KernelRegistry::KernelRegistry() { - device_type_length_ = kKernelArch_MAX - kKernelArch_MIN + 1; - data_type_length_ = kNumberTypeEnd - kNumberTypeBegin + 1; - op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN + 1; - // malloc an array contain creator functions of kernel - array_size_ = device_type_length_ * data_type_length_ * op_type_length_; - creator_arrays_ = (kernel::KernelCreator *)malloc(array_size_ * sizeof(kernel::KernelCreator)); - if (creator_arrays_ == nullptr) { - MS_LOG(ERROR) << "malloc creator_arrays_ failed."; - } else { - for (int i = 0; i < array_size_; ++i) { - creator_arrays_[i] = nullptr; - } - } -} - -KernelRegistry::~KernelRegistry() { FreeCreatorArray(); } - KernelRegistry *KernelRegistry::GetInstance() { static KernelRegistry instance; return &instance; @@ -69,18 +47,7 @@ int KernelRegistry::Init() { return RET_OK; } -void KernelRegistry::FreeCreatorArray() { - if (creator_arrays_ != nullptr) { - free(creator_arrays_); - creator_arrays_ = nullptr; - } -} - kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { - if (creator_arrays_ == nullptr) { - MS_LOG(ERROR) << "Creator func array is null."; - return nullptr; - } int index = GetCreatorFuncIndex(desc); if (index >= array_size_) { MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " @@ -104,20 +71,17 @@ int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { } void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) { - if (creator_arrays_ == nullptr) { - MS_LOG(ERROR) << "Creator func array is null."; + int index = GetCreatorFuncIndex(desc); + if (index >= array_size_) { + MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " + << desc.type; return; } - int index = GetCreatorFuncIndex(desc); creator_arrays_[index] = creator; } void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, kernel::KernelCreator creator) { - if (creator_arrays_ == nullptr) { - MS_LOG(ERROR) << "Creator func array is null."; - return; - } KernelKey desc = {arch, data_type, op_type}; int index = GetCreatorFuncIndex(desc); if (index >= array_size_) { diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h index 507138c65c..c995b01230 100644 --- a/mindspore/lite/src/kernel_registry.h +++ b/mindspore/lite/src/kernel_registry.h @@ -23,11 +23,16 @@ #include "src/lite_kernel.h" #include "schema/model_generated.h" +using mindspore::kernel::kKernelArch_MAX; +using mindspore::kernel::kKernelArch_MIN; +using mindspore::schema::PrimitiveType_MAX; +using mindspore::schema::PrimitiveType_MIN; + namespace mindspore::lite { class KernelRegistry { public: - KernelRegistry(); - virtual ~KernelRegistry(); + KernelRegistry() = default; + virtual ~KernelRegistry() = default; static KernelRegistry *GetInstance(); int Init(); @@ -44,11 +49,11 @@ class KernelRegistry { const Context *ctx, const kernel::KernelKey &key); protected: - kernel::KernelCreator *creator_arrays_ = nullptr; - size_t array_size_; - int device_type_length_; - int data_type_length_; - int op_type_length_; + static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1}; + static const int data_type_length_{kNumberTypeEnd - kNumberTypeBegin + 1}; + static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; + static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_}; + kernel::KernelCreator creator_arrays_[array_size_] = {0}; }; class KernelRegistrar {