diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index f35d27c3f3..01660faaa9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -221,7 +221,7 @@ void FreeMemoryFp16(const std::vector<kernel::LiteKernel *> &group_convs, const
   }
 }
 
-lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
+lite::Tensor *CreateInputTensorFp16(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
   auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR);
   if (in_tensor == nullptr) {
     MS_LOG(ERROR) << "new in_tensor failed.";
@@ -238,8 +238,8 @@ lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, boo
   return in_tensor;
 }
 
-lite::Tensor *CreateFilterTensor(TypeId data_type, std::vector<int> filter_shape,
-                                 const std::vector<lite::Tensor *> &inputs, int copy_length, int index) {
+lite::Tensor *CreateFilterTensorFp16(TypeId data_type, std::vector<int> filter_shape,
+                                     const std::vector<lite::Tensor *> &inputs, int copy_length, int index) {
   auto filter_tensor =
     new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
   if (filter_tensor == nullptr) {
@@ -263,8 +263,8 @@ lite::Tensor *CreateFilterTensor(TypeId data_type, std::vector<int> filter_shape
   return filter_tensor;
 }
 
-lite::Tensor *CreateBiasTensor(TypeId data_type, std::vector<int> bias_shape, const std::vector<lite::Tensor *> &inputs,
-                               int new_out_channel, int index) {
+lite::Tensor *CreateBiasTensorFp16(TypeId data_type, std::vector<int> bias_shape,
+                                   const std::vector<lite::Tensor *> &inputs, int new_out_channel, int index) {
   auto *origin_bias = inputs.at(kBiasIndex)->data_c();
   auto bias_tensor =
     new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
@@ -289,8 +289,8 @@ lite::Tensor *CreateBiasTensor(TypeId data_type, std::vector<int> bias_shape, co
   return bias_tensor;
 }
 
-lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
-                                 bool infered_flag, int index) {
+lite::Tensor *CreateOutputTensorFp16(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
+                                     bool infered_flag, int index) {
   auto out_tensor = new (std::nothrow) lite::Tensor();
   if (out_tensor == nullptr) {
     MS_LOG(ERROR) << "new tmp_out_tensor failed.";
@@ -356,7 +356,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
     int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel;
-    auto filter_tensor = CreateFilterTensor(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
+    auto filter_tensor =
+      CreateFilterTensorFp16(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
     if (filter_tensor == nullptr) {
       delete new_conv_parameter;
       FreeMemoryFp16(group_convs, new_inputs, new_outputs);
@@ -378,7 +379,8 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
-      auto bias_tensor = CreateBiasTensor(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
+      auto bias_tensor =
+        CreateBiasTensorFp16(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
       if (bias_tensor == nullptr) {
         delete new_conv_parameter;
         FreeMemoryFp16(group_convs, new_inputs, new_outputs);
@@ -390,7 +392,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
-      auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j);
+      auto out_tensor = CreateOutputTensorFp16(out_shape, outputs, infered_flag, j);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
-void FreeMemoryFp32(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
-                    const std::vector<lite::Tensor *> &new_outputs) {
+void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
+                const std::vector<lite::Tensor *> &new_outputs) {
   for (auto sub_conv : group_convs) {
     if (sub_conv != nullptr) {
       delete sub_conv;
@@ -187,7 +187,7 @@ void FreeMemoryFp32(const std::vector<kernel::LiteKernel *> &group_convs, const
   }
 }
 
-lite::Tensor *CreateInputTensorFp32(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
+lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
   auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR);
   if (in_tensor == nullptr) {
     MS_LOG(ERROR) << "new in_tensor failed.";
@@ -247,8 +247,8 @@ lite::Tensor *CreateBiasTensorFp32(TypeId data_type, std::vector<int> bias_shape
   return bias_tensor;
 }
 
-lite::Tensor *CreateOutputTensorFp32(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
-                                     bool infered_flag, int index) {
+lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
+                                 bool infered_flag, int index) {
   auto out_tensor = new (std::nothrow) lite::Tensor();
   if (out_tensor == nullptr) {
     MS_LOG(ERROR) << "new tmp_out_tensor failed.";
@@ -324,16 +324,16 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
     std::vector<lite::Tensor *> new_outputs;
     auto new_conv_parameter = CreateNewConvParameter(conv_param);
     if (new_conv_parameter == nullptr) {
-      FreeMemoryFp32(group_convs, new_inputs, new_outputs);
+      FreeMemory(group_convs, new_inputs, new_outputs);
       MS_LOG(ERROR) << "Get new conv parameter failed.";
       return nullptr;
     }
 
     // create new input for each group
-    auto in_tensor = CreateInputTensorFp32(inputs.front()->data_type(), in_shape, infered_flag);
+    auto in_tensor = CreateInputTensor(inputs.front()->data_type(), in_shape, infered_flag);
     if (in_tensor == nullptr) {
       delete new_conv_parameter;
-      FreeMemoryFp32(group_convs, new_inputs, new_outputs);
+      FreeMemory(group_convs, new_inputs, new_outputs);
       MS_LOG(ERROR) << "create input tensor failed.";
       return nullptr;
     }
@@ -345,7 +345,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
     auto filter_tensor = CreateFilterTensorFp32(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
     if (filter_tensor == nullptr) {
       delete new_conv_parameter;
-      FreeMemoryFp32(group_convs, new_inputs, new_outputs);
+      FreeMemory(group_convs, new_inputs, new_outputs);
       MS_LOG(ERROR) << "create filter tensor failed.";
       return nullptr;
     }
@@ -357,7 +357,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
      auto bias_tensor = CreateBiasTensorFp32(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
       if (bias_tensor == nullptr) {
         delete new_conv_parameter;
-        FreeMemoryFp32(group_convs, new_inputs, new_outputs);
+        FreeMemory(group_convs, new_inputs, new_outputs);
         MS_LOG(ERROR) << "create bias_tensor failed.";
         return nullptr;
       }
@@ -366,10 +366,10 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-      auto out_tensor = CreateOutputTensorFp32(out_shape, outputs, infered_flag, j);
+      auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h
+void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
+                const std::vector<lite::Tensor *> &new_outputs);
+
+ConvParameter *CreateNewConvParameter(ConvParameter *parameter);
+
+lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag);
+
+lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
+                                 bool infered_flag, int index);
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc
index c4ff456a10..23e475c626 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc
@@ -28,6 +28,11 @@ using mindspore::schema::PrimitiveType_Conv2D;
 namespace mindspore::kernel {
 int GroupConvolutionCPUKernel::Init() {
   for (int i = 0; i < group_num_; ++i) {
+    auto sub_conv = group_convs_.at(i);
+    if (sub_conv == nullptr) {
+      MS_LOG(ERROR) << "sub conv " << i << " is null.";
+      return RET_ERROR;
+    }
     auto ret = group_convs_.at(i)->Init();
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Sub kernel init failed.";
@@ -127,7 +132,7 @@ int GroupConvolutionCPUKernel::PreProcess() {
       auto ret = output->MallocData();
       if (ret != RET_OK) {
         FreeSubKernel();
-        MS_LOG(ERROR) << "fp32 group conv out tensor malloc data failed.";
+        MS_LOG(ERROR) << "group conv out tensor malloc data failed.";
         return ret;
       }
     }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.h
index c00cd71726..fdfe8dce70 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.h
@@ -41,15 +41,17 @@ class GroupConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
   int ReSize() override;
   int Run() override;
   int PreProcess() override;
-  void SeparateInput(int group_id);
-  void PostConcat(int group_id);
+  virtual void SeparateInput(int group_id);
+  virtual void PostConcat(int group_id);
   void FreeSubKernel();
 
- private:
+ protected:
   std::vector<kernel::LiteKernel *> group_convs_;
+  const int group_num_;
+
+ private:
   float *ori_in_data_ = nullptr;   // do not free
   float *ori_out_data_ = nullptr;  // do not free
-  const int group_num_;
 };
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc
index a6f9359d79..33bb636dc1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc
@@ -20,8 +20,10 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "src/runtime/kernel/arm/base/layout_transform.h"
+#include "src/runtime/kernel/arm/fp32/convolution_fp32.h"
 #include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h"
 #include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h"
+#include "src/runtime/kernel/arm/int8/group_convolution_int8.h"
 #include "src/runtime/runtime_api.h"
 #ifdef ENABLE_ARM64
 #include "src/runtime/kernel/arm/int8/opt_op_handler.h"
@@ -32,6 +34,7 @@ using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Conv2D;
+using mindspore::schema::Format::Format_NHWC;
 
 namespace mindspore::kernel {
 void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
@@ -242,6 +245,166 @@ int ConvolutionInt8CPUKernel::Run() {
   return RET_OK;
 }
 
+lite::Tensor *CreateFilterTensorInt8(TypeId data_type, std::vector<int> filter_shape,
+                                     const std::vector<lite::Tensor *> &inputs, int copy_length, int index) {
+  MS_ASSERT(data_type == kNumberTypeInt8);
+  auto filter_tensor =
+    new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
+  if (filter_tensor == nullptr) {
+    MS_LOG(ERROR) << "new filter_tensor failed.";
+    return nullptr;
+  }
+  auto ret = filter_tensor->MallocData();
+  if (ret != RET_OK) {
+    delete filter_tensor;
+    MS_LOG(ERROR) << "filter_tensor malloc failed.";
+    return nullptr;
+  }
+  auto *origin_weight = reinterpret_cast<int8_t *>(inputs.at(kWeightIndex)->data_c());
+  memcpy(filter_tensor->data_c(), origin_weight + index * copy_length, copy_length * sizeof(int8_t));
+  return filter_tensor;
+}
+
+lite::Tensor *CreateBiasTensorInt8(TypeId data_type, std::vector<int> bias_shape,
+                                   const std::vector<lite::Tensor *> &inputs, int new_out_channel, int index) {
+  MS_ASSERT(data_type == kNumberTypeInt32);
+  auto *origin_bias = inputs.at(kBiasIndex)->data_c();
+  auto bias_tensor =
+    new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
+  if (bias_tensor == nullptr) {
+    MS_LOG(ERROR) << "new bias_tensor failed.";
+    return nullptr;
+  }
+  auto ret = bias_tensor->MallocData();
+  if (ret != RET_OK) {
+    delete bias_tensor;
+    MS_LOG(ERROR) << "bias_tensor malloc failed.";
+    return nullptr;
+  }
+  auto bias_data = reinterpret_cast<int32_t *>(origin_bias);
+  memcpy(bias_tensor->data_c(), bias_data + index * new_out_channel, new_out_channel * sizeof(int32_t));
+  return bias_tensor;
+}
+
+kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &inputs,
+                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
+                                            const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) {
+  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
+  kernel::LiteKernel *kernel = nullptr;
+  if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
+      conv_param->stride_w_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1) {
+#ifdef ENABLE_ARM64
+    if (mindspore::lite::IsSupportSDot()) {
+      kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
+    } else {
+      kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
+    }
+#else
+    kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
+#endif
+  } else if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
+    kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
+  } else {
+    kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
+  }
+  return kernel;
+}
+
+kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                                  const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
+                                                  const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive,
+                                                  int group) {
+  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
+  std::vector<int> in_shape;
+  std::vector<int> out_shape;
+  int new_in_channel = inputs.at(kWeightIndex)->Channel();
+  int new_out_channel = 0;
+  if (group == 0) {
+    MS_LOG(ERROR) << "Divisor 'group' cannot be 0.";
+    return nullptr;
+  } else {
+    new_out_channel = inputs.at(kWeightIndex)->Batch() / group;
+  }
+  bool infered_flag = primitive != nullptr && primitive->infer_flag();
+  if (infered_flag) {
+    int batch = inputs.front()->Batch();
+    int in_h = inputs.front()->Height();
+    int in_w = inputs.front()->Width();
+    conv_param->input_channel_ = new_in_channel;
+    conv_param->output_channel_ = new_out_channel;
+    in_shape = {batch, in_h, in_w, new_in_channel};
+    out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel};
+  }
+  std::vector<int> filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel};
+  std::vector<int> bias_shape = {new_out_channel};
+
+  // create sub kernels
+  std::vector<kernel::LiteKernel *> group_convs;
+  for (int i = 0; i < group; ++i) {
+    std::vector<lite::Tensor *> new_inputs;
+    std::vector<lite::Tensor *> new_outputs;
+    auto new_conv_parameter = CreateNewConvParameter(conv_param);
+    if (new_conv_parameter == nullptr) {
+      FreeMemory(group_convs, new_inputs, new_outputs);
+      MS_LOG(ERROR) << "Get new conv parameter failed.";
+      return nullptr;
+    }
+
+    // create new input for each group
+    auto input_data_type = inputs.front()->data_type();
+    MS_ASSERT(input_data_type == kNumberTypeInt8);
+    auto in_tensor = CreateInputTensor(input_data_type, in_shape, infered_flag);
+    if (in_tensor == nullptr) {
+      delete new_conv_parameter;
+      FreeMemory(group_convs, new_inputs, new_outputs);
+      MS_LOG(ERROR) << "create input tensor failed.";
+      return nullptr;
+    }
+    new_inputs.emplace_back(in_tensor);
+
+    // create new weight
+    int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel;
+    auto filter_tensor =
+      CreateFilterTensorInt8(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
+    if (filter_tensor == nullptr) {
+      delete new_conv_parameter;
+      FreeMemory(group_convs, new_inputs, new_outputs);
+      MS_LOG(ERROR) << "create filter tensor failed.";
+      return nullptr;
+    }
+    new_inputs.emplace_back(filter_tensor);
+
+    // if has bias, create new bias
+    if (inputs.size() == 3) {
+      auto bias_tensor =
+        CreateBiasTensorInt8(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
+      if (bias_tensor == nullptr) {
+        delete new_conv_parameter;
+        FreeMemory(group_convs, new_inputs, new_outputs);
+        MS_LOG(ERROR) << "create bias_tensor failed.";
+        return nullptr;
+      }
+      new_inputs.emplace_back(bias_tensor);
+    }
+
+    // create new output tensor
+    for (size_t j = 0; j < outputs.size(); ++j) {
+      auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j);
+      if (out_tensor == nullptr) {
+        delete new_conv_parameter;
+        FreeMemory(group_convs, new_inputs, new_outputs);
+        MS_LOG(ERROR) << "new out_tensor failed.";
+        return nullptr;
+      }
+      new_outputs.emplace_back(out_tensor);
+    }
+    group_convs.emplace_back(CpuConvInt8KernelSelect(
+      new_inputs, new_outputs, reinterpret_cast<OpParameter *>(new_conv_parameter), ctx, primitive));
+  }
+  return new (std::nothrow)
+    GroupConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group);
+}
+
 kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                              const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                              const InnerContext *ctx, const kernel::KernelKey &desc,
@@ -249,27 +412,12 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
-  int kernel_h = conv_param->kernel_h_;
-  int kernel_w = conv_param->kernel_w_;
-  int stride_h = conv_param->stride_h_;
-  int stride_w = conv_param->stride_w_;
-  int dilation_h = conv_param->dilation_h_;
-  int dilation_w = conv_param->dilation_w_;
-  kernel::LiteKernel *kernel;
-  if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
-#ifdef ENABLE_ARM64
-    if (mindspore::lite::IsSupportSDot()) {
-      kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
-    } else {
-      kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
-    }
-#else
-    kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
-#endif
-  } else if (kernel_h == 1 && kernel_w == 1) {
-    kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  kernel::LiteKernel *kernel = nullptr;
+  if (conv_param->group_ == 1) {
+    kernel = CpuConvInt8KernelSelect(inputs, outputs, opParameter, ctx, primitive);
   } else {
-    kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+    MS_ASSERT(conv_param->group_ > 1);
+    kernel = CpuGroupConvInt8KernelCreator(inputs, outputs, opParameter, ctx, primitive, conv_param->group_);
   }
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/group_convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/group_convolution_int8.cc
new file mode 100644
index 0000000000..e10d25ee76
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/group_convolution_int8.cc
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/int8/group_convolution_int8.h"
+#include "schema/model_generated.h"
+#include "src/kernel_registry.h"
+#include "include/errorcode.h"
+
+using mindspore::kernel::KERNEL_ARCH::kCPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Conv2D;
+
+namespace mindspore::kernel {
+void GroupConvolutionInt8CPUKernel::SeparateInput(int group_id) {
+  int in_plane = conv_param_->input_h_ * conv_param_->input_w_;
+  int sub_in_channel = conv_param_->input_channel_;
+  int ori_in_channel = sub_in_channel * group_num_;
+  auto sub_in_data = reinterpret_cast<int8_t *>(group_convs_.at(group_id)->in_tensors().front()->data_c());
+  int8_t *src_ptr = ori_in_data_ + group_id * sub_in_channel;
+  int8_t *dst_ptr = sub_in_data;
+  for (int i = 0; i < in_plane; ++i) {
+    memcpy(dst_ptr, src_ptr, sub_in_channel * sizeof(int8_t));
+    src_ptr += ori_in_channel;
+    dst_ptr += sub_in_channel;
+  }
+}
+
+void GroupConvolutionInt8CPUKernel::PostConcat(int group_id) {
+  int out_plane = conv_param_->output_h_ * conv_param_->output_w_;
+  int sub_out_channel = conv_param_->output_channel_;
+  int ori_out_channel = sub_out_channel * group_num_;
+  auto sub_out_data = reinterpret_cast<int8_t *>(group_convs_.at(group_id)->out_tensors().front()->data_c());
+  int8_t *src_ptr = sub_out_data;
+  int8_t *dst_ptr = ori_out_data_ + group_id * sub_out_channel;
+  for (int i = 0; i < out_plane; ++i) {
+    memcpy(dst_ptr, src_ptr, sub_out_channel * sizeof(int8_t));
+    src_ptr += sub_out_channel;
+    dst_ptr += ori_out_channel;
+  }
+}
+
+int GroupConvolutionInt8CPUKernel::Run() {
+  ori_in_data_ = reinterpret_cast<int8_t *>(in_tensors().front()->data_c());
+  ori_out_data_ = reinterpret_cast<int8_t *>(out_tensors().front()->data_c());
+  for (int i = 0; i < group_num_; ++i) {
+    // first, separate the group conv input into several parts; this step must be done at runtime.
+    SeparateInput(i);
+    // sub kernels run
+    auto ret = group_convs_.at(i)->Run();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "sub kernel " << i << " execute failed.";
+      return ret;
+    }
+    // post process, concat all outputs of sub-kernels into one output
+    PostConcat(i);
+  }
+  return RET_OK;
+}
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/group_convolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/group_convolution_int8.h
new file mode 100644
index 0000000000..2ef3444c88
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/group_convolution_int8.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_
+
+#include <utility>
+#include <vector>
+#include "src/lite_kernel.h"
+#include "nnacl/op_base.h"
+#include "src/runtime/kernel/arm/fp32/group_convolution_fp32.h"
+
+namespace mindspore::kernel {
+class GroupConvolutionInt8CPUKernel : public GroupConvolutionCPUKernel {
+ public:
+  GroupConvolutionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                                const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                                const mindspore::lite::PrimitiveC *primitive,
+                                std::vector<kernel::LiteKernel *> group_convs, const int group_num)
+      : GroupConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive, group_convs, group_num) {
+  }  // opParameter (in channel, out channel) in this kernel has been split into groups; to get the real
+     // params, multiply in channel / out channel by the group num
+  ~GroupConvolutionInt8CPUKernel() override { GroupConvolutionCPUKernel::FreeSubKernel(); }
+
+  int Run() override;
+  void SeparateInput(int group_id) override;
+  void PostConcat(int group_id) override;
+
+ private:
+  int8_t *ori_in_data_ = nullptr;   // do not free
+  int8_t *ori_out_data_ = nullptr;  // do not free
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_
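
For reference, a minimal standalone sketch (not part of the patch; the function names and parameters are illustrative) of the strided copy that SeparateInput and PostConcat above perform: each group's sub-tensor packs sub_channel of the original sub_channel * group_num NHWC channels, so the split reads the full buffer with the original channel stride, and the concat writes back at a per-group channel offset.

#include <cstdint>
#include <cstring>

// Copy group `g` out of an NHWC buffer holding `groups * sub_ch` channels into a
// compact buffer of `sub_ch` channels; `plane` is the number of spatial positions.
void SplitGroupChannels(const int8_t *src, int8_t *dst, int plane, int sub_ch, int groups, int g) {
  const int ori_ch = sub_ch * groups;
  for (int i = 0; i < plane; ++i) {
    std::memcpy(dst + i * sub_ch, src + i * ori_ch + g * sub_ch, sub_ch * sizeof(int8_t));
  }
}

// Write a group's compact output back into its channel slice of the full NHWC buffer.
void ConcatGroupChannels(const int8_t *src, int8_t *dst, int plane, int sub_ch, int groups, int g) {
  const int ori_ch = sub_ch * groups;
  for (int i = 0; i < plane; ++i) {
    std::memcpy(dst + i * ori_ch + g * sub_ch, src + i * sub_ch, sub_ch * sizeof(int8_t));
  }
}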