From f7ee19a59008c715e59a4ba786ed0387c278cdf4 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Thu, 20 Aug 2020 21:24:37 +0800
Subject: [PATCH] 1. Fix bugs in some InferShape implementations. 2. Fix the
 fc int8 bug

---
 mindspore/lite/src/ops/dedepthwise_conv2d.cc |  4 +-
 mindspore/lite/src/ops/full_connection.cc    | 51 ++++++++++++-------
 .../kernel/arm/int8/fullconnection_int8.cc   |  4 +-
 3 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc
index 53881c51a3..8c63bb0ea1 100644
--- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc
+++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc
@@ -141,8 +141,8 @@ int DeDepthwiseConv2D::InferShape(std::vector inputs_,
   pad_u_ = GetPadUp();
   pad_d_ = GetPadDown();
   pad_r_ = GetPadRight();
-  output_h = GetStrideH() * (input_h - 1) * GetKernelH() - pad_u_ - pad_d_;
-  output_w = GetStrideW() * (input_w - 1) * GetKernelW() - pad_l_ - pad_r_;
+  output_h = GetStrideH() * (input_h - 1) + GetKernelH() - pad_u_ - pad_d_;
+  output_w = GetStrideW() * (input_w - 1) + GetKernelW() - pad_l_ - pad_r_;
   if ((output_h + GetPadUp() + GetPadDown() - GetKernelH()) % GetStrideH() != 0) {
     output_h += (output_h + GetPadLeft() + GetPadRight() - GetKernelH()) % GetStrideH();
   }
diff --git a/mindspore/lite/src/ops/full_connection.cc b/mindspore/lite/src/ops/full_connection.cc
index 20b480fa19..077c2d7e4f 100644
--- a/mindspore/lite/src/ops/full_connection.cc
+++ b/mindspore/lite/src/ops/full_connection.cc
@@ -28,7 +28,7 @@ void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullCo
 void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; }
 void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; }
 void FullConnection::SetActivationType(int activationType) {
-  this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType) activationType;
+  this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
 }
 
 #else
@@ -47,43 +47,58 @@ int FullConnection::InferShape(std::vector inputs_,
   MS_ASSERT(this->primitive != nullptr);
   auto input0 = inputs_.front();
   MS_ASSERT(input0 != nullptr);
-  auto input1 = inputs_.at(1);
+  auto input1 = inputs_[1];
   MS_ASSERT(input1 != nullptr);
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
-  output->set_data_type(input0->data_type());
-  output->SetFormat(input0->GetFormat());
   if (!GetInferFlag()) {
     return RET_OK;
   }
   if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
     MS_LOG(ERROR) << "Input tensors num error";
-    return 1;
+    return RET_INPUT_TENSOR_ERROR;
   }
-  if (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size())) {
+  if (GetUseAxis() && (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size()))) {
     MS_LOG(ERROR) << "FullConnection axis invalid";
-    return 1;
+    return RET_ERROR;
   }
   int new_k = 1;
-  for (size_t i = GetAxis(); i < input0->shape().size(); ++i) {
-    new_k *= input0->shape().at(i);
-  }
-  if (new_k != input1->shape().at(1)) {
-    MS_LOG(ERROR) << "Input1 size invalid";
-    return 1;
+  if (GetUseAxis()) {
+    for (int i = GetAxis(); i < input0->shape().size(); ++i) {
+      new_k *= input0->shape()[i];
+    }
+    if (new_k != input1->shape()[1]) {
+      MS_LOG(ERROR) << "Input1 size invalid";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+  } else {
+    new_k = input1->shape()[1];
   }
   if (GetHasBias()) {
-    if (inputs_.at(2)->shape()[0] != input1->shape()[0]) {
+    if (inputs_[2]->shape()[0] != input1->shape()[0]) {
       MS_LOG(ERROR) << "bias size invalid";
-      return 1;
+      return RET_INPUT_TENSOR_ERROR;
     }
   }
   std::vector<int> out_shape{inputs_[0]->shape()};
-  out_shape.resize(GetAxis() + 1);
-  out_shape[GetAxis()] = input1->shape()[0];
+  if (GetUseAxis()) {
+    out_shape.resize(GetAxis() + 1);
+    out_shape[GetAxis()] = input1->shape()[0];
+  } else {
+    int total = 1;
+    for (int i = 0; i < input0->shape().size(); ++i) {
+      total *= input0->shape()[i];
+    }
+    out_shape.resize(2);
+    auto batch_size = total / new_k;
+    out_shape[0] = batch_size;
+    out_shape[1] = input1->shape()[0];
+  }
   output->set_shape(out_shape);
+  output->set_data_type(input0->data_type());
+  output->SetFormat(input0->GetFormat());
 
-  return 0;
+  return RET_OK;
 }
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
index e41f7b56b9..45e2809221 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
@@ -91,8 +91,8 @@ int FullconnectionInt8CPUKernel::ReSize() {
   QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
                          &quant_params_.right_shift);
   CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
-                                    quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_max,
-                                    &quant_params_.out_act_min);
+                                    quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
+                                    &quant_params_.out_act_max);
   return RET_OK;
 }
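
Note: the DeDepthwiseConv2D hunk replaces a multiplication with an addition in the transposed-convolution output-size formula; the FullConnection hunk moves dtype/format propagation until after shape inference and adds a no-axis path that derives the batch size as the total element count divided by the depth new_k; the fullconnection_int8.cc hunk swaps the last two arguments of CalculateActivationRangeQuantized so the activation minimum and maximum land in the correct fields. A minimal standalone C++ sketch of the corrected output-size arithmetic follows (illustrative only, not MindSpore code; the function name and parameters are hypothetical):

#include <cassert>

// Transposed (de)convolution output size along one dimension:
//   output = stride * (input - 1) + kernel - pad_total
// The pre-fix code multiplied by the kernel size instead of adding it.
static int DeconvOutputSize(int input, int stride, int kernel, int pad_total) {
  return stride * (input - 1) + kernel - pad_total;
}

int main() {
  // Example: input_h = 4, stride_h = 2, kernel_h = 3, no padding.
  // Buggy formula: 2 * (4 - 1) * 3 = 18; fixed formula: 2 * (4 - 1) + 3 = 9.
  assert(DeconvOutputSize(4, 2, 3, 0) == 9);
  return 0;
}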