From 5358d552a1360bb760e7e5d4a33092909738794c Mon Sep 17 00:00:00 2001 From: fuzhiye Date: Fri, 8 Jan 2021 11:29:23 +0800 Subject: [PATCH] dispatch convolution kernel through delegate --- mindspore/lite/nnacl/int8/quantize.h | 3 - .../kernel/arm/fp16/convolution_1x1_fp16.cc | 17 +- .../kernel/arm/fp16/convolution_1x1_fp16.h | 9 +- .../arm/fp16/convolution_delegate_fp16.cc | 418 ++++++++++++++++++ .../arm/fp16/convolution_delegate_fp16.h | 65 +++ .../kernel/arm/fp16/convolution_fp16.cc | 345 +-------------- .../kernel/arm/fp16/convolution_fp16.h | 12 +- .../arm/fp16/convolution_winograd_fp16.cc | 33 +- .../arm/fp16/convolution_winograd_fp16.h | 14 +- .../kernel/arm/fp16/group_convolution_fp16.cc | 20 +- .../src/runtime/kernel/arm/fp32/adder_fp32.cc | 27 ++ .../src/runtime/kernel/arm/fp32/adder_fp32.h | 4 +- .../kernel/arm/fp32/convolution_1x1_fp32.cc | 27 +- .../kernel/arm/fp32/convolution_1x1_fp32.h | 8 +- .../arm/fp32/convolution_delegate_fp32.cc | 416 +++++++++++++++++ .../arm/fp32/convolution_delegate_fp32.h | 77 ++++ .../kernel/arm/fp32/convolution_fp32.cc | 329 +------------- .../kernel/arm/fp32/convolution_fp32.h | 18 +- .../arm/fp32/convolution_winograd_fp32.cc | 22 +- .../arm/fp32/convolution_winograd_fp32.h | 8 +- .../kernel/arm/fp32/group_convolution_fp32.cc | 20 +- .../kernel/arm/int8/convolution_int8.cc | 3 +- .../kernel/arm/fp32/conv1x1_fp32_tests.cc | 142 ------ 23 files changed, 1108 insertions(+), 929 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h diff --git a/mindspore/lite/nnacl/int8/quantize.h b/mindspore/lite/nnacl/int8/quantize.h index 72595b4cfc..f93e6b401e 100644 --- a/mindspore/lite/nnacl/int8/quantize.h +++ b/mindspore/lite/nnacl/int8/quantize.h @@ -21,9 +21,6 @@ #include #include "nnacl/op_base.h" -#define INPUT_ASYMMETRIC 0b001 -#define FILTER_ASYMMETRIC 0b010 -#define OUTPUT_ASYMMETRIC 0b100 #define INPUT_PER_CHANNEL 0b001 #define FILTER_PER_CHANNEL 0b010 #define OUTPUT_PER_CHANNEL 0b100 diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc index d3b93176bf..0f93bdb4f0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc @@ -93,13 +93,7 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; return RET_ERROR; } - auto bias_tensor = in_tensors_.at(kBiasIndex); - if (bias_tensor->data_type() == kNumberTypeFloat16) { - memcpy(bias_data_, bias_tensor->MutableData(), output_channel * sizeof(float16_t)); - } else { - Float32ToFloat16(reinterpret_cast(bias_tensor->MutableData()), reinterpret_cast(bias_data_), - output_channel); - } + memcpy(bias_data_, fp16_bias_, output_channel * sizeof(float16_t)); memset(reinterpret_cast(bias_data_) + weight_size, 0, size - weight_size); } @@ -111,8 +105,7 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { return RET_ERROR; } memset(reinterpret_cast(weight_ptr_) + down_size, 0, size - down_size); - ColMajor2Row8MajorFp16(weight_tensor->MutableData(), weight_ptr_, input_channel, output_channel, - weight_tensor->data_type() == kNumberTypeFloat16); + ColMajor2Row8MajorFp16(fp16_weight_, weight_ptr_, input_channel, output_channel, true); return RET_OK; } @@ -127,10 +120,7 @@ int Convolution1x1FP16CPUKernel::Init() { MS_LOG(ERROR) << "Init weight bias failed."; return ret; } - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); + return RET_OK; } void Convolution1x1FP16CPUKernel::FreeTmpBuffer() { @@ -143,7 +133,6 @@ void Convolution1x1FP16CPUKernel::FreeTmpBuffer() { int Convolution1x1FP16CPUKernel::ReSize() { FreeTmpBuffer(); - auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h index 78b3c95a41..ba31a28803 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h @@ -30,8 +30,11 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { public: Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive) - : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {} + const mindspore::lite::PrimitiveC *primitive, float16_t *fp16_weight, + float16_t *fp16_bias) + : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive), + fp16_weight_(fp16_weight), + fp16_bias_(fp16_bias) {} ~Convolution1x1FP16CPUKernel() override; int Init() override; @@ -53,6 +56,8 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { bool multi_thread_by_hw_ = false; int thread_count_ = 1; int thread_stride_ = 0; + float16_t *fp16_weight_; // do not free + float16_t *fp16_bias_; // do not free float16_t *weight_ptr_ = nullptr; float16_t *input_ptr_ = nullptr; float16_t *pack_input_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc new file mode 100644 index 0000000000..ad25700d9a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc @@ -0,0 +1,418 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h" +#include +#include "src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h" +#include "src/runtime/kernel/arm/fp16/convolution_fp16.h" +#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h" +#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h" +#include "src/runtime/kernel/arm/fp16/group_convolution_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/kernel/arm/base/dequant.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; +using mindspore::schema::Format::Format_NHWC; + +namespace mindspore::kernel { +void ConvolutionDelegateFP16CPUKernel::FreeCopiedData() { + if ((fp16_weight_ != nullptr) && (need_free_ & WEIGHT_NEED_FREE)) { + free(fp16_weight_); + fp16_weight_ = nullptr; + } + if ((fp16_bias_ != nullptr) && (need_free_ & BIAS_NEED_FREE)) { + free(fp16_bias_); + fp16_bias_ = nullptr; + } +} + +int ConvolutionDelegateFP16CPUKernel::GetFp16WeightAndBias() { + auto ret = GetFp16Weight(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get Fp16 Weight failed."; + return RET_ERROR; + } + + ret = GetFp16Bias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get Fp16 Bias failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDelegateFP16CPUKernel::GetFp16Weight() { + auto weight_tensor = in_tensors_.at(kWeightIndex); + if (weight_tensor->data_type() == kNumberTypeFloat16 && InferShapeDone()) { + // do not need malloc new memory to store origin data + fp16_weight_ = reinterpret_cast(weight_tensor->data_c()); + return RET_OK; + } else { + fp16_weight_ = CopyData(weight_tensor); + if (fp16_weight_ == nullptr) { + MS_LOG(ERROR) << "Generate fp16_weight failed."; + return RET_ERROR; + } + need_free_ = need_free_ | WEIGHT_NEED_FREE; + return RET_OK; + } + return RET_OK; +} + +int ConvolutionDelegateFP16CPUKernel::GetFp16Bias() { + if (in_tensors_.size() == 3) { + // has bias situation + auto bias_tensor = in_tensors_.at(kBiasIndex); + if (bias_tensor->data_type() == kNumberTypeFloat16 && InferShapeDone()) { + // do not need malloc new memory to store origin data + fp16_bias_ = reinterpret_cast(bias_tensor->data_c()); + return RET_OK; + } else { + fp16_bias_ = CopyData(bias_tensor); + if (fp16_bias_ == nullptr) { + MS_LOG(ERROR) << "Generate fp16_bias failed."; + return RET_ERROR; + } + need_free_ = need_free_ | BIAS_NEED_FREE; + return RET_OK; + } + } + return RET_OK; +} + +float16_t *ConvolutionDelegateFP16CPUKernel::CopyData(lite::Tensor *tensor) { + auto data_type = tensor->data_type(); + MS_ASSERT(data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16); + auto fp16_data = reinterpret_cast(malloc(tensor->ElementsNum() * sizeof(float16_t))); + if (fp16_data == nullptr) { + MS_LOG(ERROR) << "Malloc fp16_data failed."; + return nullptr; + } + if (data_type == kNumberTypeFloat32) { + float *origin_data = reinterpret_cast(tensor->data_c()); + for (size_t i = 0; i < tensor->ElementsNum(); ++i) { + fp16_data[i] = (float16_t)origin_data[i]; + } + } else { + auto *origin_data = reinterpret_cast(tensor->data_c()); + memcpy(fp16_data, origin_data, tensor->Size()); + } + return fp16_data; +} + +int ConvolutionDelegateFP16CPUKernel::Init() { + auto ret = GetFp16WeightAndBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get fp16 weight and bias failed."; + return ret; + } + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ConvolutionDelegateFP16CPUKernel::ReSize() { + // Update shape info of input and output + SetInputOutputShapeInfo(reinterpret_cast(op_parameter_), in_tensors_.front(), out_tensors_.front(), + context_); + if (fp16_conv_kernel_ == nullptr) { + fp16_conv_kernel_ = + CpuConvFp16KernelSelect(in_tensors_, out_tensors_, op_parameter_, context_, primitive_, fp16_weight_, fp16_bias_); + if (fp16_conv_kernel_ == nullptr) { + MS_LOG(ERROR) << "Selecting execute kernel failed for conv_kernel, got a nullptr."; + return RET_ERROR; + } + } + // copied weight and bias are not be used anymore,free them. + FreeCopiedData(); + return fp16_conv_kernel_->ReSize(); +} + +ConvParameter *CreateNewConvParameterFp16(ConvParameter *parameter) { + auto conv_parameter = reinterpret_cast(malloc(sizeof(ConvParameter))); + if (conv_parameter == nullptr) { + MS_LOG(ERROR) << "Malloc new conv parameter failed."; + return nullptr; + } + memcpy(conv_parameter, parameter, sizeof(ConvParameter)); + return conv_parameter; +} + +kernel::LiteKernel *CpuConvFp16KernelSelect(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, + float16_t *fp16_weight, float16_t *fp16_bias) { + auto conv_param = reinterpret_cast(op_parameter); + bool use_winograd = false; + int out_unit; + CheckIfUseWinogradFp16(&use_winograd, &out_unit, conv_param); + kernel::LiteKernel *kernel = nullptr; + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + kernel = new (std::nothrow) + kernel::Convolution1x1FP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, fp16_weight, fp16_bias); + } else if (use_winograd) { + kernel = new (std::nothrow) kernel::ConvolutionWinogradFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, + out_unit, fp16_weight, fp16_bias); + } else { + kernel = new (std::nothrow) + kernel::ConvolutionFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, fp16_weight, fp16_bias); + } + // Once kernel is selected, init func will invoke InitWeightAndBias + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "kernel init failed."; + delete kernel; + return nullptr; + } + return kernel; +} + +void FreeMemoryFp16(const std::vector &group_convs, const std::vector &new_inputs, + const std::vector &new_outputs) { + for (auto sub_conv : group_convs) { + delete sub_conv; + } + for (auto in_tensor : new_inputs) { + delete in_tensor; + } + for (auto out_tensor : new_outputs) { + delete out_tensor; + } +} + +static lite::Tensor *CreateInputTensorFp16(TypeId data_type, const std::vector &in_shape, bool infered_flag) { + auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); + if (in_tensor == nullptr) { + MS_LOG(ERROR) << "new in_tensor failed."; + return nullptr; + } + if (infered_flag) { + auto ret = in_tensor->MallocData(); + if (ret != RET_OK) { + delete in_tensor; + MS_LOG(ERROR) << "in tensor malloc failed."; + return nullptr; + } + } + return in_tensor; +} + +static lite::Tensor *CreateConstTensorFp16(lite::Tensor *tensor, const std::vector &shape, const int index) { + auto new_tensor = + new (std::nothrow) lite::Tensor(tensor->data_type(), shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); + if (new_tensor == nullptr) { + MS_LOG(ERROR) << "Create new_tensor failed."; + return nullptr; + } + auto ret = new_tensor->MallocData(); + if (ret != RET_OK) { + delete new_tensor; + MS_LOG(ERROR) << "Malloc new_tensor failed."; + return nullptr; + } + memcpy(new_tensor->data_c(), reinterpret_cast(tensor->data_c()) + index * new_tensor->Size(), + new_tensor->Size()); + return new_tensor; +} + +static lite::Tensor *CreateOutputTensorFp16(const std::vector &out_shape, + const std::vector &outputs, bool infered_flag, int index) { + auto out_tensor = new (std::nothrow) lite::Tensor(); + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "new tmp_out_tensor failed."; + return nullptr; + } + out_tensor->set_data_type(mindspore::kNumberTypeFloat16); + out_tensor->set_format(outputs.at(index)->format()); + if (infered_flag) { + out_tensor->set_shape(out_shape); + auto ret = out_tensor->MallocData(); + if (ret != RET_OK) { + delete out_tensor; + MS_LOG(ERROR) << "out_tensor malloc data failed."; + return nullptr; + } + } + return out_tensor; +} + +kernel::LiteKernel *CreateDelegateConvFp16(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) { + return new (std::nothrow) kernel::ConvolutionDelegateFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive); +} + +kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) { + bool infer_flag = (primitive != nullptr && primitive->infer_flag()); + auto conv_param = reinterpret_cast(op_parameter); + // update new shape info for each sub kernel + int new_in_channel = inputs.at(kWeightIndex)->Channel(); + int new_out_channel = 0; + if (conv_param->group_ == 0) { + MS_LOG(ERROR) << "Divisor 'group' cannot be 0."; + return nullptr; + } else { + new_out_channel = inputs.at(kWeightIndex)->Batch() / conv_param->group_; + } + + std::vector in_shape; + std::vector out_shape; + if (infer_flag) { + conv_param->input_channel_ = new_in_channel; + conv_param->output_channel_ = new_out_channel; + in_shape = {inputs.front()->Batch(), inputs.front()->Height(), inputs.front()->Width(), new_in_channel}; + out_shape = {inputs.front()->Batch(), outputs.front()->Height(), outputs.front()->Width(), new_out_channel}; + } + std::vector filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel}; + std::vector bias_shape = {new_out_channel}; + + // new group conv op + std::vector group_convs; + // create tensors for every sub conv kernel + for (int i = 0; i < conv_param->group_; ++i) { + std::vector new_inputs; + std::vector new_outputs; + auto new_conv_parameter = CreateNewConvParameterFp16(conv_param); + if (new_conv_parameter == nullptr) { + FreeMemoryFp16(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "Get new conv parameter failed."; + return nullptr; + } + // create new input for each group + auto in_tensor = CreateInputTensorFp16(mindspore::kNumberTypeFloat16, in_shape, infer_flag); + if (in_tensor == nullptr) { + delete new_conv_parameter; + FreeMemoryFp16(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "create input tensor failed."; + return nullptr; + } + new_inputs.emplace_back(in_tensor); + + // create new weight + auto filter_tensor = CreateConstTensorFp16(inputs.at(kWeightIndex), filter_shape, i); + if (filter_tensor == nullptr) { + delete new_conv_parameter; + FreeMemoryFp16(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "create filter tensor failed."; + return nullptr; + } + new_inputs.emplace_back(filter_tensor); + + // if has bias, create new bias + if (inputs.size() == 3) { + auto bias_tensor = CreateConstTensorFp16(inputs.at(kBiasIndex), bias_shape, i); + if (bias_tensor == nullptr) { + delete new_conv_parameter; + FreeMemoryFp16(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "create bias_tensor failed."; + return nullptr; + } + new_inputs.emplace_back(bias_tensor); + } + + // create new output tensors + for (size_t j = 0; j < outputs.size(); ++j) { + auto out_tensor = CreateOutputTensorFp16(out_shape, outputs, infer_flag, j); + if (out_tensor == nullptr) { + delete new_conv_parameter; + FreeMemoryFp16(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "new out_tensor failed."; + return nullptr; + } + new_outputs.emplace_back(out_tensor); + } + group_convs.emplace_back(CreateDelegateConvFp16( + new_inputs, new_outputs, reinterpret_cast(new_conv_parameter), ctx, primitive)); + } + return new (std::nothrow) + GroupConvolutionFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, conv_param->group_); +} + +kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + + auto *weight_tensor = inputs.at(kWeightIndex); + auto *restore_data = weight_tensor->data_c(); + auto restore_type = weight_tensor->data_type(); + bool dequant_flag = + !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; + if (dequant_flag) { + auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + free(opParameter); + return nullptr; + } + weight_tensor->set_data_type(kNumberTypeFloat32); + weight_tensor->set_data(dequant_weight); + } + + auto conv_param = reinterpret_cast(opParameter); + kernel::LiteKernel *kernel = nullptr; + if (conv_param->group_ == 1) { + kernel = CreateDelegateConvFp16(inputs, outputs, opParameter, ctx, primitive); + } else { + kernel = CpuGroupConvFp16KernelCreator(inputs, outputs, opParameter, ctx, primitive); + } + + if (kernel == nullptr) { + MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; + if (dequant_flag) { + weight_tensor->FreeData(); + weight_tensor->set_data(restore_data); + weight_tensor->set_data_type(restore_type); + } + free(opParameter); + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + if (dequant_flag) { + weight_tensor->FreeData(); + weight_tensor->set_data(restore_data); + weight_tensor->set_data_type(restore_type); + } + delete kernel; + return nullptr; + } + + if (dequant_flag) { + weight_tensor->FreeData(); + weight_tensor->set_data(restore_data); + weight_tensor->set_data_type(restore_type); + } + return kernel; +} +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h new file mode 100644 index 0000000000..44b44161ea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_DELEGATE_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_DELEGATE_FP16_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/op_base.h" + +#define WEIGHT_NEED_FREE 0b01 +#define BIAS_NEED_FREE 0b10 + +namespace mindspore::kernel { +class ConvolutionDelegateFP16CPUKernel : public LiteKernel { + public: + ConvolutionDelegateFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~ConvolutionDelegateFP16CPUKernel() override { + FreeCopiedData(); + if (fp16_conv_kernel_ != nullptr) { + op_parameter_ = nullptr; // set op_parameter of delegate to nullptr, avoiding double free + delete fp16_conv_kernel_; + fp16_conv_kernel_ = nullptr; + } + } + int GetFp16WeightAndBias(); + int GetFp16Weight(); + int GetFp16Bias(); + float16_t *CopyData(lite::Tensor *tensor); + void FreeCopiedData(); + int Init() override; + int ReSize() override; + int Run() override { return fp16_conv_kernel_->Run(); } + + private: + uint8_t need_free_ = 0b00; + kernel::LiteKernel *fp16_conv_kernel_ = nullptr; + float16_t *fp16_weight_ = nullptr; + float16_t *fp16_bias_ = nullptr; +}; + +kernel::LiteKernel *CpuConvFp16KernelSelect(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, + float16_t *fp16_weight, float16_t *fp16_bias); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_DELEGATE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 88e1b21a0d..6e28e19a30 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -16,19 +16,16 @@ #include "src/runtime/kernel/arm/fp16/convolution_fp16.h" #include -#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h" -#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h" -#include "src/runtime/kernel/arm/fp16/group_convolution_fp16.h" -#include "nnacl/fp16/conv_fp16.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/pack_fp16.h" -#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" +#include "include/errorcode.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "include/errorcode.h" #include "src/runtime/runtime_api.h" -#include "nnacl/fp16/winograd_utils_fp16.h" #include "src/runtime/kernel/arm/base/dequant.h" +#include "nnacl/fp16/conv_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/fp16/pack_fp16.h" +#include "nnacl/fp16/winograd_utils_fp16.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -49,23 +46,13 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { int pack_weight_size = oc8 * in_channel * kernel_plane; // init weight - auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Get Execute filter failed."; - return ret; - } packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float16_t))); if (packed_weight_ == nullptr) { MS_LOG(ERROR) << "malloc packed_weight_ failed."; return RET_ERROR; } memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); - RowMajor2Col8MajorFp16(execute_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false); - if (fp16_weight_ != nullptr) { - free(fp16_weight_); - fp16_weight_ = nullptr; - execute_weight_ = nullptr; - } + RowMajor2Col8MajorFp16(fp16_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false); // init bias bias_data_ = malloc(oc8 * sizeof(float16_t)); @@ -74,12 +61,9 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { return RET_ERROR; } memset(bias_data_, 0, oc8 * sizeof(float16_t)); - auto fp16_bias_data = reinterpret_cast(bias_data_); if (in_tensors_.size() == kInputSize2) { - auto ori_bias = reinterpret_cast(in_tensors_.at(kBiasIndex)->data_c()); - for (int i = 0; i < out_channel; ++i) { - fp16_bias_data[i] = (float16_t)ori_bias[i]; - } + auto fp16_bias_data = reinterpret_cast(bias_data_); + memcpy(fp16_bias_data, fp16_bias_, out_channel * sizeof(float16_t)); } else { MS_ASSERT(in_tensors_.size() == kInputSize1); } @@ -111,10 +95,7 @@ int ConvolutionFP16CPUKernel::Init() { MS_LOG(ERROR) << "Init weight bias failed."; return RET_ERROR; } - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); + return RET_OK; } int ConvolutionFP16CPUKernel::ReSize() { @@ -123,7 +104,6 @@ int ConvolutionFP16CPUKernel::ReSize() { MS_LOG(ERROR) << "Resize is invalid."; return ret; } - ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret; @@ -173,309 +153,4 @@ int ConvolutionFP16CPUKernel::Run() { FreeTmpBuffer(); return ret; } - -ConvParameter *CreateNewConvParameterFp16(ConvParameter *parameter) { - auto conv_parameter = reinterpret_cast(malloc(sizeof(ConvParameter))); - if (conv_parameter == nullptr) { - MS_LOG(ERROR) << "Malloc new conv parameter failed."; - return nullptr; - } - memcpy(conv_parameter, parameter, sizeof(ConvParameter)); - return conv_parameter; -} - -kernel::LiteKernel *CpuConvFp16KernelSelect(const std::vector &inputs, - const std::vector &outputs, OpParameter *op_parameter, - const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, - bool use_winograd, int out_unit) { - auto conv_param = reinterpret_cast(op_parameter); - if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { - return new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive); - } else if (use_winograd) { - return new (std::nothrow) - kernel::ConvolutionWinogradFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit); - } else { - return new (std::nothrow) kernel::ConvolutionFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive); - } - return nullptr; -} - -void FreeMemoryFp16(const std::vector &group_convs, const std::vector &new_inputs, - const std::vector &new_outputs) { - for (auto sub_conv : group_convs) { - delete sub_conv; - } - for (auto in_tensor : new_inputs) { - delete in_tensor; - } - for (auto out_tensor : new_outputs) { - delete out_tensor; - } -} - -lite::Tensor *CreateInputTensorFp16(TypeId data_type, std::vector in_shape, bool infered_flag) { - auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); - if (in_tensor == nullptr) { - MS_LOG(ERROR) << "new in_tensor failed."; - return nullptr; - } - if (infered_flag) { - auto ret = in_tensor->MallocData(); - if (ret != RET_OK) { - delete in_tensor; - MS_LOG(ERROR) << "in tensor malloc failed."; - return nullptr; - } - } - return in_tensor; -} - -lite::Tensor *CreateFilterTensorFp16(TypeId data_type, std::vector filter_shape, - const std::vector &inputs, int copy_length, int index) { - auto filter_tensor = - new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); - if (filter_tensor == nullptr) { - MS_LOG(ERROR) << "new filter_tensor failed."; - return nullptr; - } - auto ret = filter_tensor->MallocData(); - if (ret != RET_OK) { - delete filter_tensor; - MS_LOG(ERROR) << "filter_tensor malloc failed."; - return nullptr; - } - if (data_type == kNumberTypeFloat16) { - auto *origin_weight = reinterpret_cast(inputs.at(kWeightIndex)->data_c()); - memcpy(filter_tensor->data_c(), origin_weight + index * copy_length, copy_length * sizeof(float16_t)); - } else { - MS_ASSERT(data_type == kNumberTypeFloat32); - auto *origin_weight = reinterpret_cast(inputs.at(kWeightIndex)->data_c()); - memcpy(filter_tensor->data_c(), origin_weight + index * copy_length, copy_length * sizeof(float)); - } - return filter_tensor; -} - -lite::Tensor *CreateBiasTensorFp16(TypeId data_type, std::vector bias_shape, - const std::vector &inputs, int new_out_channel, int index) { - auto *origin_bias = inputs.at(kBiasIndex)->data_c(); - auto bias_tensor = - new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "new bias_tensor failed."; - return nullptr; - } - auto ret = bias_tensor->MallocData(); - if (ret != RET_OK) { - delete bias_tensor; - MS_LOG(ERROR) << "bias_tensor malloc failed."; - return nullptr; - } - if (data_type == kNumberTypeFloat16) { - auto bias_data = reinterpret_cast(origin_bias); - memcpy(bias_tensor->data_c(), bias_data + index * new_out_channel, new_out_channel * sizeof(float16_t)); - } else { - MS_ASSERT(data_type == kNumberTypeFloat32); - auto bias_data = reinterpret_cast(origin_bias); - memcpy(bias_tensor->data_c(), bias_data + index * new_out_channel, new_out_channel * sizeof(float)); - } - return bias_tensor; -} - -lite::Tensor *CreateOutputTensorFp16(std::vector out_shape, const std::vector &outputs, - bool infered_flag, int index) { - auto out_tensor = new (std::nothrow) lite::Tensor(); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "new tmp_out_tensor failed."; - return nullptr; - } - out_tensor->set_data_type(mindspore::kNumberTypeFloat16); - out_tensor->set_format(outputs.at(index)->format()); - if (infered_flag) { - out_tensor->set_shape(out_shape); - auto ret = out_tensor->MallocData(); - if (ret != RET_OK) { - delete out_tensor; - MS_LOG(ERROR) << "out_tensor malloc data failed."; - return nullptr; - } - } - return out_tensor; -} - -kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *op_parameter, - const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, - int group) { - int out_unit; - bool has_bias = inputs.size() == 3; - bool use_winograd = false; - bool infered_flag = (primitive != nullptr && primitive->infer_flag()); - auto conv_param = reinterpret_cast(op_parameter); - - // update new shape info for each sub kernel - int new_in_channel = inputs.at(kWeightIndex)->Channel(); - int new_out_channel = 0; - if (group == 0) { - MS_LOG(ERROR) << "Divisor 'group' cannot be 0."; - return nullptr; - } else { - new_out_channel = inputs.at(kWeightIndex)->Batch() / group; - } - - std::vector in_shape; - std::vector out_shape; - int batch = inputs.front()->Batch(); - conv_param->input_batch_ = batch; - conv_param->output_batch_ = batch; - if (infered_flag) { - conv_param->input_channel_ = new_in_channel; - conv_param->output_channel_ = new_out_channel; - CheckIfUseWinogradFp16(&use_winograd, &out_unit, conv_param); - in_shape = {batch, inputs.front()->Height(), inputs.front()->Width(), new_in_channel}; - out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel}; - } - std::vector filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel}; - std::vector bias_shape = {new_out_channel}; - - // new group conv op - std::vector group_convs; - // create tensors for every sub conv kernel - for (int i = 0; i < group; ++i) { - std::vector new_inputs; - std::vector new_outputs; - auto new_conv_parameter = CreateNewConvParameterFp16(conv_param); - if (new_conv_parameter == nullptr) { - FreeMemoryFp16(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "Get new conv parameter failed."; - return nullptr; - } - // create new input for each group - auto in_tensor = CreateInputTensorFp16(mindspore::kNumberTypeFloat16, in_shape, infered_flag); - if (in_tensor == nullptr) { - delete new_conv_parameter; - FreeMemoryFp16(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "create input tensor failed."; - return nullptr; - } - new_inputs.emplace_back(in_tensor); - - // create new weight - int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel; - auto filter_tensor = - CreateFilterTensorFp16(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i); - if (filter_tensor == nullptr) { - delete new_conv_parameter; - FreeMemoryFp16(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "create filter tensor failed."; - return nullptr; - } - new_inputs.emplace_back(filter_tensor); - - // if has bias, create new bias - if (has_bias) { - auto bias_tensor = - CreateBiasTensorFp16(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i); - if (bias_tensor == nullptr) { - delete new_conv_parameter; - FreeMemoryFp16(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "create bias_tensor failed."; - return nullptr; - } - new_inputs.emplace_back(bias_tensor); - } - - // create new output tensors - for (size_t j = 0; j < outputs.size(); ++j) { - auto out_tensor = CreateOutputTensorFp16(out_shape, outputs, infered_flag, j); - if (out_tensor == nullptr) { - delete new_conv_parameter; - FreeMemoryFp16(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "new out_tensor failed."; - return nullptr; - } - new_outputs.emplace_back(out_tensor); - } - group_convs.emplace_back(CpuConvFp16KernelSelect(new_inputs, new_outputs, - reinterpret_cast(new_conv_parameter), ctx, - primitive, use_winograd, out_unit)); - } - - return new (std::nothrow) - GroupConvolutionFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); -} - -kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const InnerContext *ctx, const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); - - auto *weight_tensor = inputs.at(kWeightIndex); - auto *restore_data = weight_tensor->data_c(); - auto restore_type = weight_tensor->data_type(); - bool dequant_flag = - !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; - if (dequant_flag) { - auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); - if (dequant_weight == nullptr) { - MS_LOG(ERROR) << "dequant data is nullptr."; - free(opParameter); - return nullptr; - } - weight_tensor->set_data_type(kNumberTypeFloat32); - weight_tensor->set_data(dequant_weight); - } - - auto conv_param = reinterpret_cast(opParameter); - bool use_winograd = false; - int out_unit; - if (primitive != nullptr && primitive->infer_flag()) { - conv_param->input_h_ = inputs.front()->Height(); - conv_param->input_w_ = inputs.front()->Width(); - conv_param->input_channel_ = inputs.front()->Channel(); - conv_param->output_h_ = outputs.front()->Height(); - conv_param->output_w_ = outputs.front()->Width(); - conv_param->output_channel_ = outputs.front()->Channel(); - conv_param->op_parameter_.thread_num_ = ctx->thread_num_; - CheckIfUseWinogradFp16(&use_winograd, &out_unit, conv_param); - } - int group = conv_param->group_; - kernel::LiteKernel *kernel = nullptr; - if (group == 1) { - kernel = CpuConvFp16KernelSelect(inputs, outputs, opParameter, ctx, primitive, use_winograd, out_unit); - } else { - kernel = CpuGroupConvFp16KernelCreator(inputs, outputs, opParameter, ctx, primitive, group); - } - - if (kernel == nullptr) { - MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; - if (dequant_flag) { - weight_tensor->FreeData(); - weight_tensor->set_data(restore_data); - weight_tensor->set_data_type(restore_type); - } - free(opParameter); - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ - << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - if (dequant_flag) { - weight_tensor->FreeData(); - weight_tensor->set_data(restore_data); - weight_tensor->set_data_type(restore_type); - } - delete kernel; - return nullptr; - } - if (dequant_flag) { - weight_tensor->FreeData(); - weight_tensor->set_data(restore_data); - weight_tensor->set_data_type(restore_type); - } - return kernel; -} -REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h index 8b13f1578f..3d998846c8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h @@ -27,13 +27,11 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { public: ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive) - : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {} + const mindspore::lite::PrimitiveC *primitive, float16_t *fp16_weight, float16_t *fp16_bias) + : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive), + fp16_weight_(fp16_weight), + fp16_bias_(fp16_bias) {} ~ConvolutionFP16CPUKernel() override { - if (fp16_weight_ != nullptr) { - free(fp16_weight_); - fp16_weight_ = nullptr; - } if (packed_weight_ != nullptr) { free(packed_weight_); packed_weight_ = nullptr; @@ -58,6 +56,8 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { col_major_input_ = nullptr; } } + float16_t *fp16_weight_; // do not free + float16_t *fp16_bias_; // do not free float16_t *packed_input_ = nullptr; float16_t *packed_weight_ = nullptr; float16_t *col_major_input_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc index 48d9111aca..5706c4c29c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc @@ -43,12 +43,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { int oc_block_num = UP_DIV(out_channel, C8NUM); // init weight - auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Get Execute filter failed."; - return ret; - } - // set data auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t); trans_weight_ = reinterpret_cast(malloc(trans_matrix_data_size)); @@ -68,21 +62,17 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { if (input_unit_ == 8) { coef = 0.5f; } - ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_); + auto ret = + CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_); if (ret != RET_OK) { MS_LOG(ERROR) << "get matrix g from CookToomFilter failed."; return ret; } - ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block); + ret = WinogradFilterTransformFp16(fp16_origin_weight_, matrix_g, matrix_gt, oc_block); if (ret != RET_OK) { MS_LOG(ERROR) << "winograd filter transfrom failed."; return ret; } - if (fp16_weight_ != nullptr) { - free(fp16_weight_); - fp16_weight_ = nullptr; - execute_weight_ = nullptr; - } // init bias bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t)); @@ -93,10 +83,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t)); auto fp16_bias_data = reinterpret_cast(bias_data_); if (in_tensors_.size() == kInputSize2) { - auto ori_bias = reinterpret_cast(in_tensors_.at(kBiasIndex)->MutableData()); - for (int i = 0; i < out_channel; ++i) { - fp16_bias_data[i] = (float16_t)ori_bias[i]; - } + memcpy(fp16_bias_data, fp16_bias_, out_channel * sizeof(float16_t)); } else { MS_ASSERT(in_tensors_.size() == kInputSize1); } @@ -163,15 +150,13 @@ int ConvolutionWinogradFP16CPUKernel::Init() { input_unit_ = output_unit_ + kernel_unit_ - 1; conv_param_->input_unit_ = input_unit_; conv_param_->output_unit_ = output_unit_; + auto ret = InitWeightBias(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init weight bias failed."; return RET_ERROR; } - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); + return RET_OK; } int ConvolutionWinogradFP16CPUKernel::ReSize() { @@ -180,17 +165,11 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() { MS_LOG(ERROR) << "Resize is invalid."; return ret; } - ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; return RET_ERROR; } - kernel_unit_ = conv_param_->kernel_h_; - input_unit_ = output_unit_ + kernel_unit_ - 1; - conv_param_->input_unit_ = input_unit_; - conv_param_->output_unit_ = output_unit_; - ret = ConfigInputOutput(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConfigInputOutput failed."; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h index 567e5a7a9f..2e15807a47 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h @@ -31,13 +31,13 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { public: ConvolutionWinogradFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive, int out_unit) - : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive), output_unit_(out_unit) {} + const mindspore::lite::PrimitiveC *primitive, int out_unit, float16_t *fp16_weight, + float16_t *fp16_bias) + : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive), + output_unit_(out_unit), + fp16_origin_weight_(fp16_weight), + fp16_bias_(fp16_bias) {} ~ConvolutionWinogradFP16CPUKernel() override { - if (fp16_weight_ != nullptr) { - free(fp16_weight_); - fp16_weight_ = nullptr; - } if (trans_weight_ != nullptr) { free(trans_weight_); trans_weight_ = nullptr; @@ -75,6 +75,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { int kernel_unit_; int input_unit_; int output_unit_; + float16_t *fp16_origin_weight_; // do not free + float16_t *fp16_bias_; // do not free float16_t *tmp_data_ = nullptr; float16_t *trans_input_ = nullptr; float16_t *gemm_out_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/group_convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/group_convolution_fp16.cc index 3fded17066..e516641786 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/group_convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/group_convolution_fp16.cc @@ -87,11 +87,8 @@ int GroupConvolutionFP16CPUKernel::PreProcess() { std::vector out_shape; for (int i = 0; i < group_num_; ++i) { // in - int in_batch = conv_param_->input_batch_; - int in_h = conv_param_->input_h_; - int in_w = conv_param_->input_w_; - int in_c = conv_param_->input_channel_; - in_shape = {in_batch, in_h, in_w, in_c}; + auto in_tensor = in_tensors_.front(); + in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(), conv_param_->input_channel_}; auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front(); sub_kernel_in_tensor->set_shape(in_shape); ret = sub_kernel_in_tensor->MallocData(); @@ -101,11 +98,8 @@ int GroupConvolutionFP16CPUKernel::PreProcess() { return ret; } // out - int out_batch = conv_param_->output_batch_; - int out_h = conv_param_->output_h_; - int out_w = conv_param_->output_w_; - int out_c = conv_param_->output_channel_; - out_shape = {out_batch, out_h, out_w, out_c}; + auto out_tensor = out_tensors_.front(); + out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(), conv_param_->output_channel_}; auto sub_kernel_out_tensors = group_convs_[i]->out_tensors(); for (auto tensor : sub_kernel_out_tensors) { tensor->set_shape(out_shape); @@ -139,7 +133,8 @@ int GroupConvolutionFP16CPUKernel::PreProcess() { int GroupConvolutionFP16CPUKernel::SeparateInput(int group_id) { // input may either be float32 or float16 - int in_plane = conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_batch_; + auto in_tensor = in_tensors_.front(); + int in_plane = in_tensor->Height() * in_tensor->Width() * in_tensor->Batch(); int sub_in_channel = conv_param_->input_channel_; int ori_in_channel = sub_in_channel * group_num_; auto sub_in_data = group_convs_.at(group_id)->in_tensors().front()->data_c(); @@ -179,7 +174,8 @@ int GroupConvolutionFP16CPUKernel::SeparateInput(int group_id) { void GroupConvolutionFP16CPUKernel::PostConcat(int group_id) { // output is must float16 data type - int out_plane = conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_; + auto out_tensor = out_tensors_.front(); + int out_plane = out_tensor->Height() * out_tensor->Width() * out_tensor->Batch(); int sub_out_channel = conv_param_->output_channel_; int ori_out_channel = sub_out_channel * group_num_; auto sub_out_data = reinterpret_cast(group_convs_.at(group_id)->out_tensors().front()->data_c()); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.cc index e059ba7af3..01ced8bb4a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.cc @@ -31,6 +31,33 @@ using mindspore::schema::PrimitiveType_Adder; using mindspore::schema::Format::Format_NHWC; namespace mindspore::kernel { +int AdderCPUKernel::Init() { + auto ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int AdderCPUKernel::ReSize() { + auto ret = ConvolutionBaseCPUKernel::CheckResizeValid(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Resize is invalid."; + return ret; + } + + ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return ret; + } + return RET_OK; +} + int AdderCPUKernel::InitWeightBias() { auto filter_tensor = in_tensors_.at(kWeightIndex); int kernel_h = filter_tensor->Height(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.h index 7f2b8c4363..468e369793 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/adder_fp32.h @@ -29,10 +29,12 @@ class AdderCPUKernel : public ConvolutionCPUKernel { AdderCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : ConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + : ConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive, nullptr, nullptr) {} ~AdderCPUKernel() override = default; int InitWeightBias() override; + int Init() override; + int ReSize() override; int Run() override; int RunImpl(int task_id) override; }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc index e6a6b66db4..981b2802d4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc @@ -44,10 +44,13 @@ void Convolution1x1CPUKernel::FreeTmpBuffer() { int Convolution1x1CPUKernel::ReSize() { FreeTmpBuffer(); - ConvolutionBaseCPUKernel::Init(); + auto error_code = ConvolutionBaseCPUKernel::Init(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv base init failed."; + return error_code; + } InitConv1x1MatmulParam(); - - int error_code = InitConv1x1Param(); + error_code = InitConv1x1Param(); if (error_code != RET_OK) { MS_LOG(ERROR) << "Convolution base init failed."; return error_code; @@ -95,7 +98,7 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; return RET_ERROR; } - memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), weight_size); + memcpy(bias_data_, origin_bias_, weight_size); memset(reinterpret_cast(bias_data_) + weight_size, 0, size - weight_size); } @@ -108,14 +111,11 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { } memset(reinterpret_cast(weight_ptr_) + down_size, 0, size - down_size); #ifdef ENABLE_AVX - RowMajor2Col16Major(reinterpret_cast(filter_tensor->MutableData()), weight_ptr_, output_channel, - input_channel); + RowMajor2Col16Major(origin_weight_, weight_ptr_, output_channel, input_channel); #elif defined(ENABLE_ARM32) - RowMajor2Col4Major(reinterpret_cast(filter_tensor->MutableData()), weight_ptr_, output_channel, - input_channel); + RowMajor2Col4Major(origin_weight_, weight_ptr_, output_channel, input_channel); #else - RowMajor2Col8Major(reinterpret_cast(filter_tensor->MutableData()), weight_ptr_, output_channel, - input_channel); + RowMajor2Col8Major(origin_weight_, weight_ptr_, output_channel, input_channel); #endif return RET_OK; } @@ -153,13 +153,10 @@ int Convolution1x1CPUKernel::Init() { } int error_code = InitConv1x1BiasWeight(); if (error_code != RET_OK) { - MS_LOG(ERROR) << "Convolution base init failed."; + MS_LOG(ERROR) << "Convolution1x1 init weight and bias failed."; return error_code; } - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); + return RET_OK; } void Convolution1x1CPUKernel::PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h index 1d8b82aa7f..e048214c96 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h @@ -34,8 +34,10 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { public: Convolution1x1CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + const mindspore::lite::PrimitiveC *primitive, float *origin_weight, float *origin_bias) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), + origin_weight_(origin_weight), + origin_bias_(origin_bias) {} ~Convolution1x1CPUKernel(); int Init() override; int Run() override; @@ -58,6 +60,8 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { bool multi_thread_by_hw_ = false; int thread_count_ = 0; int thread_stride_ = 0; + float *origin_weight_; // do not free + float *origin_bias_; // do not free float *weight_ptr_ = nullptr; float *pack_input_ = nullptr; float *input_ptr_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc new file mode 100644 index 0000000000..948c45c0e0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc @@ -0,0 +1,416 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h" +#include "src/runtime/kernel/arm/fp32/convolution_fp32.h" +#include "src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h" +#include "src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h" +#include "src/runtime/kernel/arm/fp32/group_convolution_fp32.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/kernel/arm/base/dequant.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_INFER_INVALID; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; +using mindspore::schema::Format::Format_NHWC; + +namespace mindspore::kernel { +float *ConvolutionDelegateCPUKernel::CopyData(lite::Tensor *tensor) { + auto data = reinterpret_cast(malloc(tensor->Size())); + if (data == nullptr) { + MS_LOG(ERROR) << "Malloc data failed."; + return nullptr; + } + memcpy(data, tensor->data_c(), tensor->Size()); + return data; +} + +void ConvolutionDelegateCPUKernel::FreeCopiedData() { + if (origin_weight_ != nullptr && need_free_weight_) { + free(origin_weight_); + origin_weight_ = nullptr; + } + if (origin_bias_ != nullptr && need_free_bias_) { + free(origin_bias_); + origin_bias_ = nullptr; + } +} + +int ConvolutionDelegateCPUKernel::GetWeightAndBias() { + auto ret = GetWeightData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get weight data failed."; + return ret; + } + ret = GetBiasData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get bias data failed."; + return ret; + } + return RET_OK; +} + +int ConvolutionDelegateCPUKernel::GetWeightData() { + if (InferShapeDone()) { + origin_weight_ = reinterpret_cast(in_tensors_.at(kWeightIndex)->data_c()); + return RET_OK; + } else { + origin_weight_ = CopyData(in_tensors_.at(kWeightIndex)); + if (origin_weight_ == nullptr) { + MS_LOG(ERROR) << "Copy weight data failed."; + return RET_ERROR; + } + need_free_weight_ = true; + return RET_OK; + } + return RET_OK; +} + +int ConvolutionDelegateCPUKernel::GetBiasData() { + if (in_tensors_.size() == 3) { + if (InferShapeDone()) { + origin_bias_ = reinterpret_cast(in_tensors_.at(kBiasIndex)->data_c()); + return RET_OK; + } else { + origin_bias_ = CopyData(in_tensors_.at(kBiasIndex)); + if (origin_bias_ == nullptr) { + MS_LOG(ERROR) << "Copy bias data failed."; + return RET_ERROR; + } + need_free_bias_ = true; + return RET_OK; + } + } + return RET_OK; +} + +int ConvolutionDelegateCPUKernel::Init() { + auto ret = GetWeightAndBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get weight and bias failed."; + return ret; + } + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ConvolutionDelegateCPUKernel::ReSize() { + // Updata shape info of input and output + SetInputOutputShapeInfo(reinterpret_cast(op_parameter_), in_tensors_.front(), out_tensors_.front(), + context_); + if (conv_kernel_ == nullptr) { + // need to select actual execute kernel here + conv_kernel_ = CpuConvFp32KernelSelect(in_tensors_, out_tensors_, op_parameter_, context_, primitive_, + origin_weight_, origin_bias_); + if (conv_kernel_ == nullptr) { + MS_LOG(ERROR) << "Selecting execute kernel failed for conv_kernel, got a nullptr."; + return RET_ERROR; + } + } + FreeCopiedData(); + return conv_kernel_->ReSize(); +} + +void SetInputOutputShapeInfo(ConvParameter *conv_param, const lite::Tensor *input, const lite::Tensor *output, + const InnerContext *ctx) { + conv_param->input_batch_ = input->Batch(); + conv_param->input_h_ = input->Height(); + conv_param->input_w_ = input->Width(); + conv_param->input_channel_ = input->Channel(); + conv_param->output_batch_ = output->Batch(); + conv_param->output_h_ = output->Height(); + conv_param->output_w_ = output->Width(); + conv_param->output_channel_ = output->Channel(); + conv_param->op_parameter_.thread_num_ = ctx->thread_num_; +} + +ConvParameter *CreateNewConvParameter(ConvParameter *parameter) { + auto conv_parameter = new (std::nothrow) ConvParameter; + if (conv_parameter == nullptr) { + MS_LOG(ERROR) << "Malloc new conv parameter failed."; + return nullptr; + } + memcpy(conv_parameter, parameter, sizeof(ConvParameter)); + return conv_parameter; +} + +void FreeMemory(const std::vector &group_convs, const std::vector &new_inputs, + const std::vector &new_outputs) { + for (auto sub_conv : group_convs) { + delete sub_conv; + } + for (auto in_tensor : new_inputs) { + delete in_tensor; + } + for (auto out_tensor : new_outputs) { + delete out_tensor; + } +} + +lite::Tensor *CreateInputTensor(TypeId data_type, const std::vector &in_shape, bool infered_flag) { + auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); + if (in_tensor == nullptr) { + MS_LOG(ERROR) << "new in_tensor failed."; + return nullptr; + } + if (infered_flag) { + auto ret = in_tensor->MallocData(); + if (ret != RET_OK) { + delete in_tensor; + MS_LOG(ERROR) << "in tensor malloc failed."; + return nullptr; + } + } + return in_tensor; +} + +// weight and bias are const +static lite::Tensor *CreateConstTensorFp32(lite::Tensor *tensor, const std::vector &shape, const int index) { + auto new_tensor = + new (std::nothrow) lite::Tensor(tensor->data_type(), shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); + if (new_tensor == nullptr) { + MS_LOG(ERROR) << "Create new_tensor failed."; + return nullptr; + } + auto ret = new_tensor->MallocData(); + if (ret != RET_OK) { + delete new_tensor; + MS_LOG(ERROR) << "Malloc new_tensor failed."; + return nullptr; + } + MS_ASSERT(tensor->data_type() == kNumberTypeFloat32); + memcpy(new_tensor->data_c(), reinterpret_cast(tensor->data_c()) + index * new_tensor->Size(), + new_tensor->Size()); + return new_tensor; +} + +lite::Tensor *CreateOutputTensor(const std::vector &out_shape, const std::vector &outputs, + bool infered_flag, int index) { + auto out_tensor = new (std::nothrow) lite::Tensor(); + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "new tmp_out_tensor failed."; + return nullptr; + } + out_tensor->set_data_type(outputs.at(index)->data_type()); + out_tensor->set_format(outputs.at(index)->format()); + if (infered_flag) { + out_tensor->set_shape(out_shape); + auto ret = out_tensor->MallocData(); + if (ret != RET_OK) { + delete out_tensor; + MS_LOG(ERROR) << "out_tensor malloc data failed."; + return nullptr; + } + } + return out_tensor; +} + +kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, + float *origin_weight, float *origin_bias) { + auto conv_param = reinterpret_cast(op_parameter); + bool use_winograd = false; + int out_unit; + CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); + kernel::LiteKernel *kernel = nullptr; + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + kernel = new (std::nothrow) + kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive, origin_weight, origin_bias); + } else if (use_winograd) { + kernel = new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, + out_unit, origin_weight, origin_bias); + } else { + kernel = new (std::nothrow) + kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, origin_weight, origin_bias); + } + if (kernel != nullptr) { + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "conv kernel init failed."; + delete kernel; + return nullptr; + } + } + return kernel; +} + +static kernel::LiteKernel *CreateDelegateConv(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) { + return new (std::nothrow) kernel::ConvolutionDelegateCPUKernel(op_parameter, inputs, outputs, ctx, primitive); +} + +kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) { + bool infer_flag = primitive != nullptr && primitive->infer_flag(); + auto conv_param = reinterpret_cast(op_parameter); + int new_in_channel = inputs.at(kWeightIndex)->Channel(); + int new_out_channel; + if (conv_param->group_ == 0) { + MS_LOG(ERROR) << "Divisor 'group' cannot be 0."; + return nullptr; + } else { + new_out_channel = inputs.at(kWeightIndex)->Batch() / conv_param->group_; + } + std::vector in_shape; + std::vector out_shape; + if (infer_flag) { + conv_param->input_channel_ = new_in_channel; + conv_param->output_channel_ = new_out_channel; + in_shape = {inputs.front()->Batch(), inputs.front()->Height(), inputs.front()->Width(), new_in_channel}; + out_shape = {inputs.front()->Batch(), outputs.front()->Height(), outputs.front()->Width(), new_out_channel}; + } + std::vector filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel}; + std::vector bias_shape = {new_out_channel}; + + // create sub kernels + std::vector group_convs; + for (int i = 0; i < conv_param->group_; ++i) { + std::vector new_inputs; + std::vector new_outputs; + auto new_conv_parameter = CreateNewConvParameter(conv_param); + if (new_conv_parameter == nullptr) { + FreeMemory(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "Get new conv parameter failed."; + return nullptr; + } + + // create new input for each group + auto in_tensor = CreateInputTensor(inputs.front()->data_type(), in_shape, infer_flag); + if (in_tensor == nullptr) { + delete new_conv_parameter; + FreeMemory(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "create input tensor failed."; + return nullptr; + } + new_inputs.emplace_back(in_tensor); + + // create new weight + auto filter_tensor = CreateConstTensorFp32(inputs.at(kWeightIndex), filter_shape, i); + if (filter_tensor == nullptr) { + delete new_conv_parameter; + FreeMemory(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "create filter tensor failed."; + return nullptr; + } + new_inputs.emplace_back(filter_tensor); + + // if has bias, create new bias + if (inputs.size() == 3) { + auto bias_tensor = CreateConstTensorFp32(inputs.at(kBiasIndex), bias_shape, i); + if (bias_tensor == nullptr) { + delete new_conv_parameter; + FreeMemory(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "create bias_tensor failed."; + return nullptr; + } + new_inputs.emplace_back(bias_tensor); + } + + // create new output tensor + for (size_t j = 0; j < outputs.size(); ++j) { + auto out_tensor = CreateOutputTensor(out_shape, outputs, infer_flag, j); + if (out_tensor == nullptr) { + delete new_conv_parameter; + FreeMemory(group_convs, new_inputs, new_outputs); + MS_LOG(ERROR) << "new out_tensor failed."; + return nullptr; + } + new_outputs.emplace_back(out_tensor); + } + group_convs.emplace_back( + CreateDelegateConv(new_inputs, new_outputs, reinterpret_cast(new_conv_parameter), ctx, primitive)); + } + return new (std::nothrow) + GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, conv_param->group_); +} + +kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + MS_ASSERT(op_parameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + MS_ASSERT(desc.data_type == kNumberTypeFloat32); + + // if get quantized weight, dequantize it to float32 type data. + auto *weight_tensor = inputs.at(kWeightIndex); + auto *restore_data = weight_tensor->data_c(); + auto restore_type = weight_tensor->data_type(); + bool dequant_flag = + !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; + if (dequant_flag) { + auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + free(op_parameter); + return nullptr; + } + weight_tensor->set_data(dequant_weight); + } + + auto conv_param = reinterpret_cast(op_parameter); + kernel::LiteKernel *kernel = nullptr; + if (conv_param->group_ == 1) { + kernel = CreateDelegateConv(inputs, outputs, op_parameter, ctx, primitive); + } else { + kernel = CpuGroupConvFp32KernelCreator(inputs, outputs, op_parameter, ctx, primitive); + } + + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + if (dequant_flag) { + weight_tensor->FreeData(); + weight_tensor->set_data(restore_data); + weight_tensor->set_data_type(restore_type); + } + free(op_parameter); + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK && ret != RET_INFER_INVALID) { + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + if (dequant_flag) { + weight_tensor->FreeData(); + weight_tensor->set_data(restore_data); + weight_tensor->set_data_type(restore_type); + } + delete kernel; + return nullptr; + } + + if (dequant_flag) { + weight_tensor->FreeData(); + weight_tensor->set_data(restore_data); + weight_tensor->set_data_type(restore_type); + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2D, CpuConvFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h new file mode 100644 index 0000000000..89a4d1c2c2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DELEGATE_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DELEGATE_FP32_H_ + +#include +#include "src/ops/conv2d.h" +#include "src/lite_kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/op_base.h" + +using mindspore::lite::InnerContext; +namespace mindspore::kernel { +class ConvolutionDelegateCPUKernel : public LiteKernel { + public: + ConvolutionDelegateCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~ConvolutionDelegateCPUKernel() override { + FreeCopiedData(); + if (conv_kernel_ != nullptr) { + op_parameter_ = nullptr; // op_parameter will be freed in conv_kernel + delete conv_kernel_; + conv_kernel_ = nullptr; + } + }; + int Init() override; + int ReSize() override; + int Run() override { return conv_kernel_->Run(); } + int GetWeightAndBias(); + int GetWeightData(); + int GetBiasData(); + static float *CopyData(lite::Tensor *tensor); + void FreeCopiedData(); + + protected: + bool need_free_weight_ = false; + bool need_free_bias_ = false; + kernel::LiteKernel *conv_kernel_ = nullptr; + float *origin_weight_ = nullptr; + float *origin_bias_ = nullptr; +}; + +void SetInputOutputShapeInfo(ConvParameter *conv_param, const lite::Tensor *input, const lite::Tensor *output, + const InnerContext *ctx); + +void FreeMemory(const std::vector &group_convs, const std::vector &new_inputs, + const std::vector &new_outputs); + +ConvParameter *CreateNewConvParameter(ConvParameter *parameter); + +lite::Tensor *CreateInputTensor(TypeId data_type, const std::vector &in_shape, bool infered_flag); + +lite::Tensor *CreateOutputTensor(const std::vector &out_shape, const std::vector &outputs, + bool infered_flag, int index); + +kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector &inputs, + const std::vector &outputs, OpParameter *op_parameter, + const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, + float *origin_weight, float *origin_bias); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DELEGATE_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc index fa3dc8242a..9c8416038a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc @@ -15,16 +15,13 @@ */ #include "src/runtime/kernel/arm/fp32/convolution_fp32.h" -#include "src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h" -#include "src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h" -#include "src/runtime/kernel/arm/fp32/group_convolution_fp32.h" -#include "nnacl/fp32/conv_fp32.h" -#include "nnacl/common_func.h" +#include "include/errorcode.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "include/errorcode.h" #include "src/runtime/runtime_api.h" #include "src/runtime/kernel/arm/base/dequant.h" +#include "nnacl/fp32/conv_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -52,7 +49,6 @@ int ConvolutionCPUKernel::InitWeightBias() { int oc_block_num = UP_ROUND(out_channel, oc_block); int pack_weight_size = oc_block_num * in_channel * kernel_plane; - auto origin_weight = reinterpret_cast(filter_tensor->data_c()); packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); if (packed_weight_ == nullptr) { MS_LOG(ERROR) << "malloc packed weight failed."; @@ -60,11 +56,11 @@ int ConvolutionCPUKernel::InitWeightBias() { } memset(packed_weight_, 0, pack_weight_size * sizeof(float)); #ifdef ENABLE_AVX - RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); + RowMajor2Col16Major(origin_weight_, packed_weight_, out_channel, in_channel * kernel_plane); #elif ENABLE_ARM32 - RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); + RowMajor2Col4Major(origin_weight_, packed_weight_, out_channel, in_channel * kernel_plane); #else - RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); + RowMajor2Col8Major(origin_weight_, packed_weight_, out_channel, in_channel * kernel_plane); #endif bias_data_ = reinterpret_cast(malloc(oc_block_num * sizeof(float))); @@ -75,8 +71,7 @@ int ConvolutionCPUKernel::InitWeightBias() { memset(bias_data_, 0, oc_block_num * sizeof(float)); if (in_tensors_.size() == kInputSize2) { - auto ori_bias = reinterpret_cast(in_tensors_.at(kBiasIndex)->data_c()); - memcpy(bias_data_, ori_bias, out_channel * sizeof(float)); + memcpy(bias_data_, origin_bias_, out_channel * sizeof(float)); } else { MS_ASSERT(in_tensors_.size() == kInputSize1); } @@ -114,10 +109,7 @@ int ConvolutionCPUKernel::Init() { MS_LOG(ERROR) << "Init weight bias failed."; return RET_ERROR; } - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); + return RET_OK; } int ConvolutionCPUKernel::ReSize() { @@ -126,11 +118,10 @@ int ConvolutionCPUKernel::ReSize() { MS_LOG(ERROR) << "Resize is invalid."; return ret; } - ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { - MS_LOG(ERROR) << "ConvolutionBase init failed."; - return RET_ERROR; + MS_LOG(ERROR) << "conv base init failed."; + return ret; } return RET_OK; } @@ -168,304 +159,4 @@ int ConvolutionCPUKernel::Run() { FreeTmpBuffer(); return ret; } - -ConvParameter *CreateNewConvParameter(ConvParameter *parameter) { - auto conv_parameter = new (std::nothrow) ConvParameter; - if (conv_parameter == nullptr) { - MS_LOG(ERROR) << "Malloc new conv parameter failed."; - return nullptr; - } - memcpy(conv_parameter, parameter, sizeof(ConvParameter)); - return conv_parameter; -} - -void FreeMemory(const std::vector &group_convs, const std::vector &new_inputs, - const std::vector &new_outputs) { - for (auto sub_conv : group_convs) { - delete sub_conv; - } - for (auto in_tensor : new_inputs) { - delete in_tensor; - } - for (auto out_tensor : new_outputs) { - delete out_tensor; - } -} - -lite::Tensor *CreateInputTensor(TypeId data_type, std::vector in_shape, bool infered_flag) { - auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); - if (in_tensor == nullptr) { - MS_LOG(ERROR) << "new in_tensor failed."; - return nullptr; - } - if (infered_flag) { - auto ret = in_tensor->MallocData(); - if (ret != RET_OK) { - delete in_tensor; - MS_LOG(ERROR) << "in tensor malloc failed."; - return nullptr; - } - } - return in_tensor; -} - -lite::Tensor *CreateFilterTensorFp32(TypeId data_type, std::vector filter_shape, - const std::vector &inputs, int copy_length, int index) { - auto filter_tensor = - new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); - if (filter_tensor == nullptr) { - MS_LOG(ERROR) << "new filter_tensor failed."; - return nullptr; - } - auto ret = filter_tensor->MallocData(); - if (ret != RET_OK) { - delete filter_tensor; - MS_LOG(ERROR) << "filter_tensor malloc failed."; - return nullptr; - } - - MS_ASSERT(data_type == kNumberTypeFloat32); - auto *origin_weight = reinterpret_cast(inputs.at(kWeightIndex)->data_c()); - memcpy(filter_tensor->data_c(), origin_weight + index * copy_length, copy_length * sizeof(float)); - return filter_tensor; -} - -lite::Tensor *CreateBiasTensorFp32(TypeId data_type, std::vector bias_shape, - const std::vector &inputs, int new_out_channel, int index) { - auto *origin_bias = inputs.at(kBiasIndex)->data_c(); - auto bias_tensor = - new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "new bias_tensor failed."; - return nullptr; - } - auto ret = bias_tensor->MallocData(); - if (ret != RET_OK) { - delete bias_tensor; - MS_LOG(ERROR) << "bias_tensor malloc failed."; - return nullptr; - } - MS_ASSERT(data_type == kNumberTypeFloat32); - auto bias_data = reinterpret_cast(origin_bias); - memcpy(bias_tensor->data_c(), bias_data + index * new_out_channel, new_out_channel * sizeof(float)); - - return bias_tensor; -} - -lite::Tensor *CreateOutputTensor(std::vector out_shape, const std::vector &outputs, - bool infered_flag, int index) { - auto out_tensor = new (std::nothrow) lite::Tensor(); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "new tmp_out_tensor failed."; - return nullptr; - } - out_tensor->set_data_type(outputs.at(index)->data_type()); - out_tensor->set_format(outputs.at(index)->format()); - if (infered_flag) { - out_tensor->set_shape(out_shape); - auto ret = out_tensor->MallocData(); - if (ret != RET_OK) { - delete out_tensor; - MS_LOG(ERROR) << "out_tensor malloc data failed."; - return nullptr; - } - } - return out_tensor; -} - -kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector &inputs, - const std::vector &outputs, OpParameter *op_parameter, - const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, - bool use_winograd, int out_unit) { - auto conv_param = reinterpret_cast(op_parameter); - if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { - return new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive); - } else if (use_winograd) { - return new (std::nothrow) - kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit); - } else { - return new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive); - } - return nullptr; -} - -kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *op_parameter, - const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, - int group) { - int out_unit; - bool has_bias = inputs.size() == 3; - bool use_winograd = false; - bool infered_flag = primitive != nullptr && primitive->infer_flag(); - auto conv_param = reinterpret_cast(op_parameter); - - std::vector in_shape; - std::vector out_shape; - int new_in_channel = inputs.at(kWeightIndex)->Channel(); - int new_out_channel = 0; - if (group == 0) { - MS_LOG(ERROR) << "Divisor 'group' cannot be 0."; - return nullptr; - } else { - new_out_channel = inputs.at(kWeightIndex)->Batch() / group; - } - int batch = inputs.front()->Batch(); - conv_param->input_batch_ = batch; - conv_param->output_batch_ = batch; - if (infered_flag) { - int in_h = inputs.front()->Height(); - int in_w = inputs.front()->Width(); - conv_param->input_channel_ = new_in_channel; - conv_param->output_channel_ = new_out_channel; - CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); - in_shape = {batch, in_h, in_w, new_in_channel}; - out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel}; - } - std::vector filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel}; - std::vector bias_shape = {new_out_channel}; - - // create sub kernels - std::vector group_convs; - for (int i = 0; i < group; ++i) { - std::vector new_inputs; - std::vector new_outputs; - auto new_conv_parameter = CreateNewConvParameter(conv_param); - if (new_conv_parameter == nullptr) { - FreeMemory(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "Get new conv parameter failed."; - return nullptr; - } - - // create new input for each group - auto in_tensor = CreateInputTensor(inputs.front()->data_type(), in_shape, infered_flag); - if (in_tensor == nullptr) { - delete new_conv_parameter; - FreeMemory(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "create input tensor failed."; - return nullptr; - } - new_inputs.emplace_back(in_tensor); - - // create new weight - int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel; - auto filter_tensor = - CreateFilterTensorFp32(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i); - if (filter_tensor == nullptr) { - delete new_conv_parameter; - FreeMemory(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "create filter tensor failed."; - return nullptr; - } - new_inputs.emplace_back(filter_tensor); - - // if has bias, create new bias - if (has_bias) { - auto bias_tensor = - CreateBiasTensorFp32(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i); - if (bias_tensor == nullptr) { - delete new_conv_parameter; - FreeMemory(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "create bias_tensor failed."; - return nullptr; - } - new_inputs.emplace_back(bias_tensor); - } - - // create new output tensor - for (size_t j = 0; j < outputs.size(); ++j) { - auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j); - if (out_tensor == nullptr) { - delete new_conv_parameter; - FreeMemory(group_convs, new_inputs, new_outputs); - MS_LOG(ERROR) << "new out_tensor failed."; - return nullptr; - } - new_outputs.emplace_back(out_tensor); - } - group_convs.emplace_back(CpuConvFp32KernelSelect(new_inputs, new_outputs, - reinterpret_cast(new_conv_parameter), ctx, - primitive, use_winograd, out_unit)); - } - - return new (std::nothrow) - GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); -} - -kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *op_parameter, - const InnerContext *ctx, const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { - MS_ASSERT(op_parameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); - MS_ASSERT(desc.data_type == kNumberTypeFloat32); - auto conv_param = reinterpret_cast(op_parameter); - int group = conv_param->group_; - bool use_winograd = false; - int out_unit; - if (primitive != nullptr && primitive->infer_flag()) { - conv_param->input_h_ = inputs.front()->Height(); - conv_param->input_w_ = inputs.front()->Width(); - conv_param->input_channel_ = inputs.front()->Channel(); - conv_param->output_h_ = outputs.front()->Height(); - conv_param->output_w_ = outputs.front()->Width(); - conv_param->output_channel_ = outputs.front()->Channel(); - conv_param->op_parameter_.thread_num_ = ctx->thread_num_; - CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); - } - - auto *weight_tensor = inputs.at(kWeightIndex); - auto *restore_data = weight_tensor->data_c(); - auto restore_type = weight_tensor->data_type(); - bool dequant_flag = - !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; - if (dequant_flag) { - auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); - if (dequant_weight == nullptr) { - MS_LOG(ERROR) << "dequant data is nullptr."; - free(op_parameter); - return nullptr; - } - weight_tensor->set_data(dequant_weight); - } - - kernel::LiteKernel *kernel; - if (group == 1) { - kernel = CpuConvFp32KernelSelect(inputs, outputs, op_parameter, ctx, primitive, use_winograd, out_unit); - } else { - kernel = CpuGroupConvFp32KernelCreator(inputs, outputs, op_parameter, ctx, primitive, group); - } - - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - if (dequant_flag) { - weight_tensor->FreeData(); - weight_tensor->set_data(restore_data); - weight_tensor->set_data_type(restore_type); - } - free(op_parameter); - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK && ret != RET_INFER_INVALID) { - MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); - if (dequant_flag) { - weight_tensor->FreeData(); - weight_tensor->set_data(restore_data); - weight_tensor->set_data_type(restore_type); - } - delete kernel; - return nullptr; - } - - if (dequant_flag) { - weight_tensor->FreeData(); - weight_tensor->set_data(restore_data); - weight_tensor->set_data_type(restore_type); - } - - return kernel; -} - -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2D, CpuConvFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h index b05c68f5eb..b5a4762f90 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h @@ -28,8 +28,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { public: ConvolutionCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + const mindspore::lite::PrimitiveC *primitive, float *origin_weight, float *origin_bias) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), + origin_weight_(origin_weight), + origin_bias_(origin_bias) {} ~ConvolutionCPUKernel() override { if (packed_weight_ != nullptr) { free(packed_weight_); @@ -57,20 +59,12 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { } protected: + float *origin_weight_; // do not free + float *origin_bias_; // do not free float *packed_weight_ = nullptr; float *packed_input_ = nullptr; float *col_major_input_ = nullptr; }; - -void FreeMemory(const std::vector &group_convs, const std::vector &new_inputs, - const std::vector &new_outputs); - -ConvParameter *CreateNewConvParameter(ConvParameter *parameter); - -lite::Tensor *CreateInputTensor(TypeId data_type, std::vector in_shape, bool infered_flag); - -lite::Tensor *CreateOutputTensor(std::vector out_shape, const std::vector &outputs, - bool infered_flag, int index); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc index 33ffa0da34..c5b87ef621 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc @@ -81,8 +81,7 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { MS_LOG(ERROR) << "get matrix g from CookToomFilter failed."; return ret; } - auto weight_data = reinterpret_cast(filter_tensor->MutableData()); - ret = WinogradFilterTransform(weight_data, matrix_g, matrix_gt, oc_block); + ret = WinogradFilterTransform(origin_weight_, matrix_g, matrix_gt, oc_block); if (ret != RET_OK) { MS_LOG(ERROR) << "winograd filter transfrom failed."; return ret; @@ -97,8 +96,7 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { } memset(bias_data_, 0, new_bias_size); if (in_tensors_.size() == kInputSize2) { - auto ori_bias_addr = reinterpret_cast(in_tensors_.at(kBiasIndex)->MutableData()); - memcpy(bias_data_, ori_bias_addr, out_channel * sizeof(float)); + memcpy(bias_data_, origin_bias_, out_channel * sizeof(float)); } else { MS_ASSERT(in_tensors_.size() == kInputSize1); } @@ -171,10 +169,7 @@ int ConvolutionWinogradCPUKernel::Init() { MS_LOG(ERROR) << "Init weight bias failed."; return RET_ERROR; } - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); + return RET_OK; } int ConvolutionWinogradCPUKernel::ReSize() { @@ -183,18 +178,11 @@ int ConvolutionWinogradCPUKernel::ReSize() { MS_LOG(ERROR) << "Resize is invalid."; return ret; } - ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { - MS_LOG(ERROR) << "ConvolutionBase init failed."; - return RET_ERROR; + MS_LOG(ERROR) << "conv base init failed."; + return ret; } - - kernel_unit_ = conv_param_->kernel_h_; - input_unit_ = output_unit_ + kernel_unit_ - 1; - conv_param_->input_unit_ = input_unit_; - conv_param_->output_unit_ = output_unit_; - ret = ConfigInputOutput(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConfigInputOutput failed."; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h index ee22d8bff0..6e9c26efaa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h @@ -28,10 +28,12 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { public: ConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive, int output_unit) + const mindspore::lite::PrimitiveC *primitive, int output_unit, float *origin_weight, + float *origin_bias) : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), output_unit_(output_unit), - trans_weight_(nullptr) {} + origin_weight_(origin_weight), + origin_bias_(origin_bias) {} ~ConvolutionWinogradCPUKernel() override { if (trans_weight_ != nullptr) { free(trans_weight_); @@ -69,6 +71,8 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { int kernel_unit_; int input_unit_; int output_unit_; + float *origin_weight_; // do not free + float *origin_bias_; // do not free float *tmp_data_ = nullptr; float *trans_input_ = nullptr; float *gemm_out_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc index 2624efb960..eb86a433df 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/group_convolution_fp32.cc @@ -92,11 +92,8 @@ int GroupConvolutionCPUKernel::PreProcess() { std::vector out_shape; for (int i = 0; i < group_num_; ++i) { // in - int in_batch = conv_param_->input_batch_; - int in_h = conv_param_->input_h_; - int in_w = conv_param_->input_w_; - int in_c = conv_param_->input_channel_; - in_shape = {in_batch, in_h, in_w, in_c}; + auto in_tensor = in_tensors_.front(); + in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(), conv_param_->input_channel_}; auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front(); sub_kernel_in_tensor->set_shape(in_shape); ret = sub_kernel_in_tensor->MallocData(); @@ -106,11 +103,8 @@ int GroupConvolutionCPUKernel::PreProcess() { return ret; } // out - int out_batch = conv_param_->output_batch_; - int out_h = conv_param_->output_h_; - int out_w = conv_param_->output_w_; - int out_c = conv_param_->output_channel_; - out_shape = {out_batch, out_h, out_w, out_c}; + auto out_tensor = out_tensors_.front(); + out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(), conv_param_->output_channel_}; auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors(); for (auto tensor : sub_kernel_out_tensors) { tensor->set_shape(out_shape); @@ -143,7 +137,8 @@ int GroupConvolutionCPUKernel::PreProcess() { } void GroupConvolutionCPUKernel::SeparateInput(int group_id) { - int in_plane = conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_batch_; + auto in_tensor = in_tensors_.front(); + int in_plane = in_tensor->Height() * in_tensor->Width() * in_tensor->Batch(); int sub_in_channel = conv_param_->input_channel_; int ori_in_channel = sub_in_channel * group_num_; auto sub_in_data = reinterpret_cast(group_convs_.at(group_id)->in_tensors().front()->data_c()); @@ -157,7 +152,8 @@ void GroupConvolutionCPUKernel::SeparateInput(int group_id) { } void GroupConvolutionCPUKernel::PostConcat(int group_id) { - int out_plane = conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_; + auto out_tensor = out_tensors_.front(); + int out_plane = out_tensor->Height() * out_tensor->Width() * out_tensor->Batch(); int sub_out_channel = conv_param_->output_channel_; int ori_out_channel = sub_out_channel * group_num_; auto sub_out_data = reinterpret_cast(group_convs_.at(group_id)->out_tensors().front()->data_c()); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index c5242cf534..616778ff39 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -19,8 +19,7 @@ #include "nnacl/int8/conv_int8.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "src/runtime/kernel/arm/base/layout_transform.h" -#include "src/runtime/kernel/arm/fp32/convolution_fp32.h" +#include "src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h" #include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h" #include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" #include "src/runtime/kernel/arm/int8/group_convolution_int8.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc index 71c037bd6e..1f2629f33c 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -139,146 +139,4 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack4) { EXPECT_EQ(0, CompareOutputData(out, correct, 54)); delete conv_param; } - -int Conv1x1TestInit1(std::vector *inputs_, std::vector *outputs_, - ConvParameter *conv_param, float **correct) { - auto *in_t = new lite::Tensor(kNumberTypeFloat, {1, 2, 3, 4}, schema::Format_NHWC, lite::Tensor::VAR); - in_t->MallocData(); - float in[] = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, 14.727955, -8.080715, - 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, - 6.7263746, -3.9301445, 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; - memcpy(in_t->MutableData(), in, sizeof(float) * 24); - inputs_->push_back(in_t); - - auto *weight_t = new lite::Tensor(kNumberTypeFloat, {3, 1, 1, 4}, schema::Format_NHWC, lite::Tensor::CONST_TENSOR); - weight_t->MallocData(); - float weight[] = {-0.7308652, 0.5257509, -0.87825793, -1.123181, -1.2206168, 0.562695, - 1.5382664, -0.5020635, 0.8591602, -0.26410004, 1.1262615, 0.073132955}; /* nhwc */ - memcpy(weight_t->MutableData(), weight, sizeof(float) * 12); - inputs_->push_back(weight_t); - - auto *bias_t = new lite::Tensor(kNumberTypeFloat, {3}, schema::Format_NHWC, lite::Tensor::CONST_TENSOR); - bias_t->MallocData(); - float bias[] = {2, 2, 2}; - memcpy(bias_t->MutableData(), bias, sizeof(float) * 3); - inputs_->push_back(bias_t); - - auto *out_t = new lite::Tensor(kNumberTypeFloat, {1, 2, 3, 3}, schema::Format_NHWC, lite::Tensor::VAR); - out_t->MallocData(); - outputs_->push_back(out_t); - - *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); - float co[] = {2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1.3731456, 1.6877825, 12.427691, 2., 2., 2.}; - memcpy(*correct, co, out_t->ElementsNum() * sizeof(float)); - - conv_param->kernel_h_ = conv_param->kernel_w_ = 1; - conv_param->stride_h_ = conv_param->stride_w_ = 2; - conv_param->dilation_h_ = conv_param->dilation_w_ = 1; - conv_param->pad_u_ = conv_param->pad_l_ = 1; - conv_param->act_type_ = ActType_No; - return out_t->ElementsNum(); -} - -TEST_F(TestConv1x1Fp32, Conv1x1Test1) { - std::vector inputs_; - std::vector outputs_; - auto conv_param = new ConvParameter(); - auto *ctx = new lite::InnerContext(); - ctx->thread_num_ = 1; - ASSERT_EQ(lite::RET_OK, ctx->Init()); - float *correct; - int total_size = Conv1x1TestInit1(&inputs_, &outputs_, conv_param, &correct); - auto *conv1x1 = - new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx, nullptr); - - conv1x1->Init(); - conv1x1->Run(); - - ASSERT_EQ(0, CompareOutputData(reinterpret_cast(outputs_[0]->MutableData()), correct, total_size, 0.0001)); - delete conv_param; - delete conv1x1; - for (auto t : inputs_) delete t; - for (auto t : outputs_) delete t; - free(correct); -} - -int Conv1x1TestInit2(std::vector *inputs_, std::vector *outputs_, - ConvParameter *conv_param, float **correct) { - size_t buffer_size; - auto *in_t = new lite::Tensor(kNumberTypeFloat, {1, 300, 300, 24}, schema::Format_NHWC, lite::Tensor::VAR); - in_t->MallocData(); - std::string input_path = "./conv/conv1x1fp32_input1_nhwc.bin"; - auto in = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &buffer_size)); - memcpy(in_t->MutableData(), in, buffer_size); - inputs_->push_back(in_t); - - auto *weight_t = new lite::Tensor(kNumberTypeFloat, {40, 1, 1, 24}, schema::Format_NHWC, lite::Tensor::CONST_TENSOR); - weight_t->MallocData(); - std::string weight_path = "./conv/conv1x1fp32_weight1_nhwc.bin"; - auto weight = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &buffer_size)); - memcpy(weight_t->MutableData(), weight, buffer_size); - inputs_->push_back(weight_t); - - auto *bias_t = new lite::Tensor(kNumberTypeFloat, {40}, schema::Format_NHWC, lite::Tensor::CONST_TENSOR); - bias_t->MallocData(); - std::string bias_path = "./conv/conv1x1fp32_bias1_nhwc.bin"; - auto bias = mindspore::lite::ReadFile(bias_path.c_str(), &buffer_size); - memcpy(bias_t->MutableData(), bias, buffer_size); - inputs_->push_back(bias_t); - - auto *out_t = new lite::Tensor(kNumberTypeFloat, {1, 300, 300, 40}, schema::Format_NHWC, lite::Tensor::VAR); - out_t->MallocData(); - outputs_->push_back(out_t); - - std::string out_path = "./conv/conv1x1fp32_output1_nhwc.bin"; - auto out_nhwc = mindspore::lite::ReadFile(out_path.c_str(), &buffer_size); - *correct = reinterpret_cast(malloc(buffer_size)); - memcpy(*correct, out_nhwc, buffer_size); - - conv_param->kernel_h_ = conv_param->kernel_w_ = 1; - conv_param->stride_h_ = conv_param->stride_w_ = 1; - conv_param->dilation_h_ = conv_param->dilation_w_ = 1; - conv_param->pad_u_ = conv_param->pad_l_ = 0; - conv_param->act_type_ = ActType_No; - return out_t->ElementsNum(); -} - -TEST_F(TestConv1x1Fp32, Conv1x1Test2) { - std::vector inputs_; - std::vector outputs_; - auto conv_param = new ConvParameter(); - auto *ctx = new lite::InnerContext(); - ctx->thread_num_ = 2; - ASSERT_EQ(lite::RET_OK, ctx->Init()); - float *correct; - int total_size = Conv1x1TestInit2(&inputs_, &outputs_, conv_param, &correct); - auto *conv1x1 = - new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx, nullptr); - - conv1x1->Init(); - conv1x1->Run(); - ASSERT_EQ(0, CompareOutputData(reinterpret_cast(outputs_[0]->MutableData()), correct, total_size, 0.0001)); - - /* running warm up */ - for (int i = 0; i < 0; i++) { - conv1x1->Run(); - } - - /* running time cost */ - int loop_count = 1; - auto time_start = mindspore::lite::GetTimeUs(); - for (int i = 0; i < loop_count; i++) { - conv1x1->Run(); - } - auto time_end = mindspore::lite::GetTimeUs(); - auto cost = time_end - time_start; - uint64_t time_avg = cost / loop_count; - printf("1x1 average time : %f ms\n", time_avg / 1000.0f); - - delete conv_param; - delete conv1x1; - for (auto t : inputs_) delete t; - for (auto t : outputs_) delete t; - free(correct); -} } // namespace mindspore