!14038 [MSLITE][DEVELOP] fix bug of npu convolution

From: @yangruoqi713
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/14038/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 3d2093b666

@ -25,6 +25,10 @@ using mindspore::schema::PrimitiveType_Conv2DFusion;
namespace mindspore::kernel {
int ConvolutionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
if (conv_param_->stride_h_ > inputs[0]->Height() || conv_param_->stride_w_ > inputs[0]->Width()) {
MS_LOG(ERROR) << "Npu convolution does not support stride greater than input size.";
return RET_ERROR;
}
return RET_OK;
}
@ -108,10 +112,14 @@ kernel::LiteKernel *NpuConvKernelCreator(const std::vector<lite::Tensor *> &inpu
const lite::InnerContext *ctx, const kernel::KernelKey &desc) {
MS_ASSERT(op_parameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DFusion);
if (inputs[0]->Size() > NPU_MEMORY_MAX) {
MS_LOG(ERROR) << "Npu does not support input tensor size greater than 200MB";
free(op_parameter);
return nullptr;
}
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
kernel::NPUKernel *kernel = nullptr;
if (conv_param->group_ == 1) {
kernel = new (std::nothrow) kernel::ConvolutionNPUKernel(op_parameter, inputs, outputs, ctx);
} else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) {

@ -27,6 +27,7 @@ using mindspore::kernel::LiteKernel;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
#define NPU_MEMORY_MAX 200 * 1024 * 1024
class NPUKernel : public LiteKernel {
public:
NPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@ -62,6 +63,11 @@ kernel::LiteKernel *NPUKernelCreator(const std::vector<lite::Tensor *> &inputs,
free(op_parameter);
return nullptr;
}
if (inputs[0]->Size() > NPU_MEMORY_MAX) {
MS_LOG(ERROR) << "Npu does not support input tensor size greater than 200MB";
free(op_parameter);
return nullptr;
}
auto *kernel = new (std::nothrow) T(op_parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel " << op_parameter->name_ << "is nullptr.";

Loading…
Cancel
Save