!4674 [MS][LITE][Develop]refactor kernel creator

Merge pull request !4674 from sunsuodong/refactor_kernel_creator
pull/4674/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit e60c0b605c

@ -28,30 +28,8 @@ using mindspore::kernel::kCPU;
using mindspore::kernel::KERNEL_ARCH; using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelCreator; using mindspore::kernel::KernelCreator;
using mindspore::kernel::KernelKey; using mindspore::kernel::KernelKey;
using mindspore::kernel::kKernelArch_MAX;
using mindspore::kernel::kKernelArch_MIN;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::lite { namespace mindspore::lite {
KernelRegistry::KernelRegistry() {
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN + 1;
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin + 1;
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN + 1;
// malloc an array contain creator functions of kernel
array_size_ = device_type_length_ * data_type_length_ * op_type_length_;
creator_arrays_ = (kernel::KernelCreator *)malloc(array_size_ * sizeof(kernel::KernelCreator));
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "malloc creator_arrays_ failed.";
} else {
for (int i = 0; i < array_size_; ++i) {
creator_arrays_[i] = nullptr;
}
}
}
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); }
KernelRegistry *KernelRegistry::GetInstance() { KernelRegistry *KernelRegistry::GetInstance() {
static KernelRegistry instance; static KernelRegistry instance;
return &instance; return &instance;
@ -69,18 +47,7 @@ int KernelRegistry::Init() {
return RET_OK; return RET_OK;
} }
void KernelRegistry::FreeCreatorArray() {
if (creator_arrays_ != nullptr) {
free(creator_arrays_);
creator_arrays_ = nullptr;
}
}
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "Creator func array is null.";
return nullptr;
}
int index = GetCreatorFuncIndex(desc); int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) { if (index >= array_size_) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
@ -104,20 +71,17 @@ int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
} }
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) { void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) {
if (creator_arrays_ == nullptr) { int index = GetCreatorFuncIndex(desc);
MS_LOG(ERROR) << "Creator func array is null."; if (index >= array_size_) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
<< desc.type;
return; return;
} }
int index = GetCreatorFuncIndex(desc);
creator_arrays_[index] = creator; creator_arrays_[index] = creator;
} }
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type,
kernel::KernelCreator creator) { kernel::KernelCreator creator) {
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "Creator func array is null.";
return;
}
KernelKey desc = {arch, data_type, op_type}; KernelKey desc = {arch, data_type, op_type};
int index = GetCreatorFuncIndex(desc); int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) { if (index >= array_size_) {

@ -23,11 +23,16 @@
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
using mindspore::kernel::kKernelArch_MAX;
using mindspore::kernel::kKernelArch_MIN;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::lite { namespace mindspore::lite {
class KernelRegistry { class KernelRegistry {
public: public:
KernelRegistry(); KernelRegistry() = default;
virtual ~KernelRegistry(); virtual ~KernelRegistry() = default;
static KernelRegistry *GetInstance(); static KernelRegistry *GetInstance();
int Init(); int Init();
@ -44,11 +49,11 @@ class KernelRegistry {
const Context *ctx, const kernel::KernelKey &key); const Context *ctx, const kernel::KernelKey &key);
protected: protected:
kernel::KernelCreator *creator_arrays_ = nullptr; static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1};
size_t array_size_; static const int data_type_length_{kNumberTypeEnd - kNumberTypeBegin + 1};
int device_type_length_; static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
int data_type_length_; static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_};
int op_type_length_; kernel::KernelCreator creator_arrays_[array_size_] = {0};
}; };
class KernelRegistrar { class KernelRegistrar {

Loading…
Cancel
Save