|
|
|
@ -35,30 +35,15 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
|
|
|
|
|
auto input_tensor = inputs.at(kInputIndex);
|
|
|
|
|
auto data_type = input_tensor->data_type();
|
|
|
|
|
kernel::LiteKernel *kernel = nullptr;
|
|
|
|
|
switch (data_type) {
|
|
|
|
|
case kNumberTypeInt8:
|
|
|
|
|
case kNumberTypeUInt8: {
|
|
|
|
|
kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case kNumberTypeFloat32: {
|
|
|
|
|
kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
if (data_type == kNumberTypeInt8 || data_type == kNumberTypeUInt8) {
|
|
|
|
|
kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
} else {
|
|
|
|
|
kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
}
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto ret = kernel->Init();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
delete kernel;
|
|
|
|
|