diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index 116976e6c3..a7e01bc624 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -34,10 +34,6 @@ ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() { free(bias_data_); bias_data_ = nullptr; } - if (nhwc4_input_ != nullptr) { - free(nhwc4_input_); - nhwc4_input_ = nullptr; - } } void ConvolutionBaseCPUKernel::FreeQuantParam() { @@ -112,18 +108,6 @@ int ConvolutionBaseCPUKernel::CheckResizeValid() { return RET_OK; } -int ConvolutionBaseCPUKernel::CheckLayout(lite::Tensor *input_tensor) { - auto data_type = input_tensor->data_type(); - auto input_format = input_tensor->GetFormat(); - schema::Format execute_format = schema::Format::Format_NHWC4; - convert_func_ = LayoutTransform(data_type, input_format, execute_format); - if (convert_func_ == nullptr) { - MS_LOG(ERROR) << "layout convert func is nullptr."; - return RET_ERROR; - } - return RET_OK; -} - int ConvolutionBaseCPUKernel::SetIfPerChannel() { auto filter_tensor = in_tensors_.at(kWeightIndex); auto input_channel = filter_tensor->Channel(); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h index 1d58abc978..11a2ac1e1c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h @@ -48,7 +48,6 @@ class ConvolutionBaseCPUKernel : public LiteKernel { int Init() override; int ReSize() override { return 0; } int Run() override { return 0; } - virtual int CheckLayout(lite::Tensor *input_tensor); int SetIfAsymmetric(); int SetIfPerChannel(); int MallocQuantParam(); @@ -61,14 +60,12 @@ class ConvolutionBaseCPUKernel : public LiteKernel { void FreeQuantParam(); protected: - int tile_num_; void *bias_data_ = nullptr; - void *nhwc4_input_ = nullptr; const InnerContext *ctx_; - int thread_count_; ConvParameter *conv_param_; ConvQuantArg *conv_quant_arg_; - LayoutConvertor convert_func_ = nullptr; + int tile_num_; + int thread_count_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 70655f23d5..9015aadd35 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -61,6 +61,10 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { } memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); RowMajor2Col8MajorFp16(execute_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false); + if (fp16_weight_ != nullptr) { + free(fp16_weight_); + fp16_weight_ = nullptr; + } // init bias bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t)); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index 87c18fbcbf..9bc93235ba 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -17,6 +17,7 @@ #include "src/runtime/kernel/arm/fp32/convolution.h" #include "src/runtime/kernel/arm/fp32/convolution_1x1.h" #include "src/runtime/kernel/arm/fp32/convolution_winograd.h" +#include "src/runtime/kernel/arm/fp32/group_convolution.h" #include "nnacl/fp32/conv.h" #include "nnacl/common_func.h" #include "schema/model_generated.h" @@ -31,6 +32,7 @@ using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Conv2D; +using mindspore::schema::Format::Format_NHWC; namespace mindspore::kernel { int ConvolutionCPUKernel::InitWeightBias() { @@ -157,6 +159,108 @@ int ConvolutionCPUKernel::Run() { return RET_OK; } +kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, + bool use_winograd, int out_unit) { + auto conv_param = reinterpret_cast(op_parameter); + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + return new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive); + } else if (use_winograd) { + return new (std::nothrow) + kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit); + } else { + return new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive); + } + return nullptr; +} + +kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, + int group) { + std::vector group_convs; + std::vector in_shape; + std::vector filter_shape; + std::vector bias_shape; + std::vector out_shape; + + auto conv_param = reinterpret_cast(op_parameter); + int out_channel = inputs.at(kWeightIndex)->Batch(); + int new_in_channel = inputs.at(kWeightIndex)->Channel(); + int new_out_channel = out_channel / group; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int input_num = inputs.size(); + int output_num = outputs.size(); + bool has_bias = input_num == 3; + bool use_winograd = false; + int out_unit; + + if (primitive != nullptr && primitive->GetInferFlag()) { + int batch = inputs.front()->Batch(); + int in_h = inputs.front()->Height(); + int in_w = inputs.front()->Width(); + conv_param->input_channel_ = new_in_channel; + conv_param->output_channel_ = new_out_channel; + CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); + in_shape = {batch, in_h, in_w, new_in_channel}; + out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel}; + } + + filter_shape = {new_out_channel, kernel_h, kernel_w, new_in_channel}; + bias_shape = {new_out_channel}; + auto *origin_weight = reinterpret_cast(inputs.at(kWeightIndex)->data_c()); + auto *origin_bias = reinterpret_cast(inputs.at(kBiasIndex)->data_c()); + + for (int i = 0; i < group; ++i) { + std::vector new_inputs; + std::vector new_outputs; + // get new input for each group + auto in_tensor = + new (std::nothrow) lite::Tensor(inputs.front()->data_type(), in_shape, Format_NHWC, lite::Tensor::Category::VAR); + if (primitive != nullptr && primitive->GetInferFlag()) { + in_tensor->MallocData(); + } + new_inputs.emplace_back(in_tensor); + + // nwe weight + auto filter_tensor = new (std::nothrow) + lite::Tensor(inputs.at(kWeightIndex)->data_type(), filter_shape, Format_NHWC, lite::Tensor::Category::CONST); + filter_tensor->MallocData(); + int copy_length = kernel_h * kernel_w * new_in_channel * new_out_channel; + memcpy(filter_tensor->data_c(), origin_weight + i * copy_length, copy_length * sizeof(float)); + new_inputs.emplace_back(filter_tensor); + + // if has bias, set new bias + if (has_bias) { + auto bias_tensor = new (std::nothrow) + lite::Tensor(inputs.at(kBiasIndex)->data_type(), bias_shape, Format_NHWC, lite::Tensor::Category::CONST); + bias_tensor->MallocData(); + memcpy(bias_tensor->data_c(), origin_bias + i * new_out_channel, new_out_channel * sizeof(float)); + new_inputs.emplace_back(bias_tensor); + } + + // set new output tensor + for (int j = 0; j < output_num; ++j) { + auto tmp_out_tensor = new (std::nothrow) lite::Tensor(); + tmp_out_tensor->set_data_type(outputs.at(j)->data_type()); + tmp_out_tensor->SetFormat(outputs.at(j)->GetFormat()); + if (primitive != nullptr && primitive->GetInferFlag()) { + tmp_out_tensor->set_shape(out_shape); + tmp_out_tensor->MallocData(); + } + new_outputs.emplace_back(tmp_out_tensor); + } + + group_convs.emplace_back( + CpuConvFp32KernelSelect(new_inputs, new_outputs, op_parameter, ctx, primitive, use_winograd, out_unit)); + } + // sub kernels and group conv kernel share the same op_parameter struct + return new (std::nothrow) + GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); +} + kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *op_parameter, const InnerContext *ctx, const kernel::KernelKey &desc, @@ -164,8 +268,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & MS_ASSERT(op_parameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); auto conv_param = reinterpret_cast(op_parameter); - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; + int group = conv_param->group_; bool use_winograd = false; int out_unit; if (primitive != nullptr && primitive->GetInferFlag()) { @@ -192,14 +295,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & } kernel::LiteKernel *kernel; - if (kernel_h == 1 && kernel_w == 1) { - kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive); - } else if (use_winograd) { - kernel = - new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit); + if (group == 1) { + kernel = CpuConvFp32KernelSelect(inputs, outputs, op_parameter, ctx, primitive, use_winograd, out_unit); } else { - kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive); + kernel = CpuGroupConvFp32KernelCreator(inputs, outputs, op_parameter, ctx, primitive, group); } + if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution.cc new file mode 100644 index 0000000000..7754e5ba87 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/group_convolution.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +int GroupConvolutionCPUKernel::Init() { + for (int i = 0; i < group_num_; ++i) { + auto ret = group_convs_[i]->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Sub kernel init failed."; + return ret; + } + } + // if infer shape is done, resize func will be invoked in sub kernels + return RET_OK; +} + +int GroupConvolutionCPUKernel::ReSize() { + for (int i = 0; i < group_num_; ++i) { + auto ret = group_convs_[i]->ReSize(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Sub kernel resize failed."; + return RET_ERROR; + } + } + conv_param_->input_channel_ /= group_num_; + conv_param_->output_channel_ /= group_num_; + return RET_OK; +} + +int GroupConvolutionCPUKernel::PreProcess() { + if (!InferShapeDone()) { + auto ret = (const_cast(primitive_))->InferShape(in_tensors_, out_tensors_); + if (ret != 0) { + (const_cast(primitive_))->SetInferFlag(false); + MS_LOG(ERROR) << "InferShape fail!"; + return ret; + } + (const_cast(primitive_))->SetInferFlag(true); + ret = ReSize(); + if (ret != 0) { + MS_LOG(ERROR) << "ReSize fail!ret: " << ret; + return ret; + } + + // if infershape func is called in runtime stage, we should malloc memory and set shape info for outputs of sub + // kernels here. + std::vector in_shape; + std::vector out_shape; + for (int i = 0; i < group_num_; ++i) { + // in + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_c = conv_param_->input_channel_; + in_shape = {in_batch, in_h, in_w, in_c}; + auto sub_kernel_in_tensor = group_convs_[i]->in_tensors().front(); + sub_kernel_in_tensor->set_shape(in_shape); + sub_kernel_in_tensor->MallocData(); + // out + int out_batch = conv_param_->output_batch_; + int out_h = conv_param_->output_h_; + int out_w = conv_param_->output_w_; + int out_c = conv_param_->output_channel_; + out_shape = {out_batch, out_h, out_w, out_c}; + auto sub_kernel_out_tensors = group_convs_[i]->out_tensors(); + for (auto tensor : sub_kernel_out_tensors) { + tensor->set_shape(out_shape); + tensor->MallocData(); + } + } + } + + auto outputs = this->out_tensors(); + for (auto *output : outputs) { + MS_ASSERT(output != nullptr); + output->MallocData(); + } + return RET_OK; +} + +void GroupConvolutionCPUKernel::SeparateInput(int group_id) { + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_plane = in_h * in_w; + int sub_in_channel = conv_param_->input_channel_; + int ori_in_channel = sub_in_channel * group_num_; + auto sub_in_data = reinterpret_cast(group_convs_[group_id]->in_tensors().front()->data_c()); + float *src_ptr = ori_in_data_ + group_id * sub_in_channel; + float *dst_ptr = sub_in_data; + for (int i = 0; i < in_plane; ++i) { + memcpy(dst_ptr, src_ptr, sub_in_channel * sizeof(float)); + src_ptr += ori_in_channel; + dst_ptr += sub_in_channel; + } +} + +void GroupConvolutionCPUKernel::PostConcat(int group_id) { + int out_h = conv_param_->output_h_; + int out_w = conv_param_->output_w_; + int out_plane = out_h * out_w; + int sub_out_channel = conv_param_->output_channel_; + int ori_out_channel = sub_out_channel * group_num_; + auto sub_out_data = reinterpret_cast(group_convs_[group_id]->out_tensors().front()->data_c()); + float *src_ptr = sub_out_data; + float *dst_ptr = ori_out_data_ + group_id * sub_out_channel; + for (int i = 0; i < out_plane; ++i) { + memcpy(dst_ptr, src_ptr, sub_out_channel * sizeof(float)); + src_ptr += sub_out_channel; + dst_ptr += ori_out_channel; + } +} + +int GroupConvolutionCPUKernel::Run() { + ori_in_data_ = reinterpret_cast(in_tensors().front()->data_c()); + ori_out_data_ = reinterpret_cast(out_tensors().front()->data_c()); + for (int i = 0; i < group_num_; ++i) { + // first, separate group conv input into several parts. This step must be in runtime stage. + SeparateInput(i); + // sun kernels run + group_convs_[i]->Run(); + // post process, concat all outputs of sub-kernels into one output + PostConcat(i); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution.h new file mode 100644 index 0000000000..3a9583c2f0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "nnacl/op_base.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "nnacl/fp32/conv.h" + +namespace mindspore::kernel { +class GroupConvolutionCPUKernel : public ConvolutionBaseCPUKernel { + public: + GroupConvolutionCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive, std::vector group_convs, + const int group_num) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), + group_convs_(std::move(group_convs)), + group_num_(group_num) {} // opParameter(in channel, out channel) in this kernel has been split to groups, if + // you want to get real params, multiply in channel / out channel with group num + ~GroupConvolutionCPUKernel() override { + for (auto sub_conv : group_convs_) { + // free sub conv input tensors / output tensors manually + auto sub_in_tensors = sub_conv->in_tensors(); + auto sub_in_tensor_num = sub_in_tensors.size(); + for (size_t i = 0; i < sub_in_tensor_num; ++i) { + delete sub_in_tensors[i]; + } + auto sub_out_tensors = sub_conv->out_tensors(); + auto sub_out_tensor_num = sub_out_tensors.size(); + for (size_t i = 0; i < sub_out_tensor_num; ++i) { + delete sub_out_tensors[i]; + } + delete sub_conv; + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + int PreProcess() override; + void SeparateInput(int group_id); + void PostConcat(int group_id); + + private: + std::vector group_convs_; + float *ori_in_data_ = nullptr; // do not free + float *ori_out_data_ = nullptr; // do not free + const int group_num_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_