|
|
|
@ -28,30 +28,8 @@ using mindspore::kernel::kCPU;
|
|
|
|
|
using mindspore::kernel::KERNEL_ARCH;
|
|
|
|
|
using mindspore::kernel::KernelCreator;
|
|
|
|
|
using mindspore::kernel::KernelKey;
|
|
|
|
|
using mindspore::kernel::kKernelArch_MAX;
|
|
|
|
|
using mindspore::kernel::kKernelArch_MIN;
|
|
|
|
|
using mindspore::schema::PrimitiveType_MAX;
|
|
|
|
|
using mindspore::schema::PrimitiveType_MIN;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::lite {
|
|
|
|
|
KernelRegistry::KernelRegistry() {
|
|
|
|
|
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN + 1;
|
|
|
|
|
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin + 1;
|
|
|
|
|
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN + 1;
|
|
|
|
|
// malloc an array contain creator functions of kernel
|
|
|
|
|
array_size_ = device_type_length_ * data_type_length_ * op_type_length_;
|
|
|
|
|
creator_arrays_ = (kernel::KernelCreator *)malloc(array_size_ * sizeof(kernel::KernelCreator));
|
|
|
|
|
if (creator_arrays_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "malloc creator_arrays_ failed.";
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < array_size_; ++i) {
|
|
|
|
|
creator_arrays_[i] = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); }
|
|
|
|
|
|
|
|
|
|
KernelRegistry *KernelRegistry::GetInstance() {
|
|
|
|
|
static KernelRegistry instance;
|
|
|
|
|
return &instance;
|
|
|
|
@ -69,18 +47,7 @@ int KernelRegistry::Init() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRegistry::FreeCreatorArray() {
|
|
|
|
|
if (creator_arrays_ != nullptr) {
|
|
|
|
|
free(creator_arrays_);
|
|
|
|
|
creator_arrays_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
|
|
|
|
|
if (creator_arrays_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Creator func array is null.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
int index = GetCreatorFuncIndex(desc);
|
|
|
|
|
if (index >= array_size_) {
|
|
|
|
|
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
|
|
|
|
@ -104,20 +71,17 @@ int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) {
|
|
|
|
|
if (creator_arrays_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Creator func array is null.";
|
|
|
|
|
int index = GetCreatorFuncIndex(desc);
|
|
|
|
|
if (index >= array_size_) {
|
|
|
|
|
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
|
|
|
|
|
<< desc.type;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
int index = GetCreatorFuncIndex(desc);
|
|
|
|
|
creator_arrays_[index] = creator;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type,
|
|
|
|
|
kernel::KernelCreator creator) {
|
|
|
|
|
if (creator_arrays_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Creator func array is null.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
KernelKey desc = {arch, data_type, op_type};
|
|
|
|
|
int index = GetCreatorFuncIndex(desc);
|
|
|
|
|
if (index >= array_size_) {
|
|
|
|
|