!4674 [MS][LITE][Develop]refactor kernel creator

Merge pull request !4674 from sunsuodong/refactor_kernel_creator
pull/4674/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit e60c0b605c

@ -28,30 +28,8 @@ using mindspore::kernel::kCPU;
using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelCreator;
using mindspore::kernel::KernelKey;
using mindspore::kernel::kKernelArch_MAX;
using mindspore::kernel::kKernelArch_MIN;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::lite {
KernelRegistry::KernelRegistry() {
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN + 1;
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin + 1;
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN + 1;
// malloc an array contain creator functions of kernel
array_size_ = device_type_length_ * data_type_length_ * op_type_length_;
creator_arrays_ = (kernel::KernelCreator *)malloc(array_size_ * sizeof(kernel::KernelCreator));
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "malloc creator_arrays_ failed.";
} else {
for (int i = 0; i < array_size_; ++i) {
creator_arrays_[i] = nullptr;
}
}
}
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); }
KernelRegistry *KernelRegistry::GetInstance() {
static KernelRegistry instance;
return &instance;
@ -69,18 +47,7 @@ int KernelRegistry::Init() {
return RET_OK;
}
void KernelRegistry::FreeCreatorArray() {
if (creator_arrays_ != nullptr) {
free(creator_arrays_);
creator_arrays_ = nullptr;
}
}
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "Creator func array is null.";
return nullptr;
}
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
@ -104,20 +71,17 @@ int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
}
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) {
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "Creator func array is null.";
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
<< desc.type;
return;
}
int index = GetCreatorFuncIndex(desc);
creator_arrays_[index] = creator;
}
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type,
kernel::KernelCreator creator) {
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "Creator func array is null.";
return;
}
KernelKey desc = {arch, data_type, op_type};
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) {

@ -23,11 +23,16 @@
#include "src/lite_kernel.h"
#include "schema/model_generated.h"
using mindspore::kernel::kKernelArch_MAX;
using mindspore::kernel::kKernelArch_MIN;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::lite {
class KernelRegistry {
public:
KernelRegistry();
virtual ~KernelRegistry();
KernelRegistry() = default;
virtual ~KernelRegistry() = default;
static KernelRegistry *GetInstance();
int Init();
@ -44,11 +49,11 @@ class KernelRegistry {
const Context *ctx, const kernel::KernelKey &key);
protected:
kernel::KernelCreator *creator_arrays_ = nullptr;
size_t array_size_;
int device_type_length_;
int data_type_length_;
int op_type_length_;
static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1};
static const int data_type_length_{kNumberTypeEnd - kNumberTypeBegin + 1};
static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_};
kernel::KernelCreator creator_arrays_[array_size_] = {0};
};
class KernelRegistrar {

Loading…
Cancel
Save