From ba4dec43d492ca2d46d2b9625001da9d333f03c2 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Thu, 13 Aug 2020 10:44:01 +0800
Subject: [PATCH] Fix fc op's bug

---
 mindspore/lite/schema/ops.fbs                 |  1 +
 mindspore/lite/src/ops/fullconnection.cc      | 40 ++++++++++++++-----
 mindspore/lite/src/populate_parameter.cc      |  9 ++++-
 .../kernel/arm/fp32/power_fp32_tests.cc       | 30 ++++++++++----
 .../tflite/tflite_fullyconnected_parser.cc    |  7 ++++
 5 files changed, 67 insertions(+), 20 deletions(-)

diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index f07fccd0df..cbb07b0be8 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -352,6 +352,7 @@ table FullConnection {
     hasBias: bool;
     axis: int;
     useAxis: bool;
+    activationType: ActivationType = 0;
 }
 
 // Mean(input_tensor, axis, keep_dims)
diff --git a/mindspore/lite/src/ops/fullconnection.cc b/mindspore/lite/src/ops/fullconnection.cc
index 7b4b1e051f..4e32bc66d3 100644
--- a/mindspore/lite/src/ops/fullconnection.cc
+++ b/mindspore/lite/src/ops/fullconnection.cc
@@ -24,7 +24,7 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
   MS_ASSERT(this->primitive != nullptr);
   auto input0 = inputs_.front();
   MS_ASSERT(input0 != nullptr);
-  auto input1 = inputs_.at(1);
+  auto input1 = inputs_[1];
   MS_ASSERT(input1 != nullptr);
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
@@ -33,27 +33,45 @@
     MS_LOG(ERROR) << "Input tensors num error";
     return RET_INPUT_TENSOR_ERROR;
   }
-  if (fc_prim->axis() < 1 || fc_prim->axis() > input0->shape().size()) {
+  auto axis = fc_prim->axis();
+  auto use_axis = fc_prim->useAxis();
+  if (use_axis && (axis < 1 || axis >= input0->shape().size())) {
     MS_LOG(ERROR) << "FullConnection axis invalid";
     return RET_INPUT_TENSOR_ERROR;
   }
   int new_k = 1;
-  for (size_t i = fc_prim->axis(); i < input0->shape().size(); ++i) {
-    new_k *= input0->shape().at(i);
-  }
-  if (new_k != input1->shape().at(1)) {
-    MS_LOG(ERROR) << "Input1 size invalid";
-    return RET_PARAM_INVALID;
+  if (use_axis) {
+    for (int i = axis; i < input0->shape().size(); ++i) {
+      new_k *= input0->shape()[i];
+    }
+    if (new_k != input1->shape()[1]) {
+      MS_LOG(ERROR) << "Input1 size invalid";
+      return RET_PARAM_INVALID;
+    }
+  } else {
+    new_k = input1->shape()[1];
   }
+
   if (fc_prim->hasBias()) {
-    if (inputs_.at(2)->shape()[0] != input1->shape()[0]) {
+    if (inputs_[2]->shape()[0] != input1->shape()[0]) {
       MS_LOG(ERROR) << "bias size invalid";
       return RET_PARAM_INVALID;
     }
   }
   std::vector<int> out_shape{inputs_[0]->shape()};
-  out_shape.resize(fc_prim->axis() + 1);
-  out_shape[fc_prim->axis()] = input1->shape()[0];
+  if (use_axis) {
+    out_shape.resize(fc_prim->axis() + 1);
+    out_shape[fc_prim->axis()] = input1->shape()[0];
+  } else {
+    int total = 1;
+    for (int i = 0; i < input0->shape().size(); ++i) {
+      total *= input0->shape()[i];
+    }
+    out_shape.resize(2);
+    auto batch_size = total / new_k;
+    out_shape[0] = batch_size;
+    out_shape[1] = input1->shape()[0];
+  }
   output->set_shape(out_shape);
   output->set_data_type(input0->data_type());
   output->SetFormat(input0->GetFormat());
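
[Note on the InferShape change above: when useAxis is false, the weight
tensor alone fixes the reduced depth (new_k = input1->shape()[1]) and the
activation is implicitly flattened to (total / new_k, new_k), so the output
is always rank 2. Below is a minimal standalone sketch of the two code
paths; FcOutShape is a hypothetical helper that uses plain std::vector<int>
shapes instead of the MindSpore tensor types.]

    #include <cassert>
    #include <functional>
    #include <numeric>
    #include <vector>

    // `in` is the activation shape, `weight` is {out_channels, k}.
    std::vector<int> FcOutShape(const std::vector<int> &in,
                                const std::vector<int> &weight,
                                bool use_axis, int axis) {
      if (use_axis) {
        // Collapse dims [axis, end) into k and keep the leading dims,
        // mirroring the use_axis branch of the patched InferShape.
        int new_k = 1;
        for (size_t i = axis; i < in.size(); ++i) new_k *= in[i];
        assert(new_k == weight[1]);
        std::vector<int> out(in.begin(), in.begin() + axis);
        out.push_back(weight[0]);
        return out;
      }
      // No axis: flatten everything to (batch, k), with k read off the weight.
      int total = std::accumulate(in.begin(), in.end(), 1, std::multiplies<int>());
      return {total / weight[1], weight[0]};
    }

[For example, in = {2, 3, 4} with weight = {5, 12} and use_axis false now
infers {2, 5}; before this patch the axis check rejected that case, since an
unset axis of 0 failed the axis < 1 test.]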
diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc
index 25ceaeb960..80a00476cf 100644
--- a/mindspore/lite/src/populate_parameter.cc
+++ b/mindspore/lite/src/populate_parameter.cc
@@ -226,7 +226,14 @@ OpParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) {
   matmul_param->b_transpose_ = true;
   matmul_param->a_transpose_ = false;
   matmul_param->has_bias_ = param->hasBias();
-  matmul_param->act_type_ = ActType_No;
+  if (param->activationType() == schema::ActivationType_RELU) {
+    matmul_param->act_type_ = ActType_Relu;
+  } else if (param->activationType() == schema::ActivationType_RELU6) {
+    matmul_param->act_type_ = ActType_Relu6;
+  } else {
+    matmul_param->act_type_ = ActType_No;
+  }
+
   return reinterpret_cast<OpParameter *>(matmul_param);
 }
 
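
[Note on the hunk above: only RELU and RELU6 are translated into the fused
activation field of the matmul parameter; any other activationType value
falls back to ActType_No. A standalone sketch of the same mapping, with
stand-in enums because the real schema::ActivationType and ActType come from
the generated FlatBuffer schema and the kernel parameter headers:]

    // Stand-ins for schema::ActivationType_* and ActType_*.
    enum class SchemaAct { NoActivation, Relu, Relu6, Sigmoid };
    enum class KernelAct { No, Relu, Relu6 };

    KernelAct MapFusedAct(SchemaAct t) {
      switch (t) {
        case SchemaAct::Relu:  return KernelAct::Relu;
        case SchemaAct::Relu6: return KernelAct::Relu6;
        default:               return KernelAct::No;  // anything else is not fused
      }
    }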
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc
index 49e4e01e5b..1e53abf055 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc
@@ -48,6 +48,22 @@ int PowerTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite
   return out_t->ElementsNum();
 }
 
+int PowerTestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
+                   float *a_ptr, std::vector<int> a_shape, std::vector<int> c_shape) {
+  auto in_t =
+    new lite::tensor::Tensor(kNumberTypeFloat, a_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  in_t->MallocData();
+  memcpy(in_t->Data(), a_ptr, sizeof(float) * in_t->ElementsNum());
+  inputs_->push_back(in_t);
+
+  auto out_t =
+    new lite::tensor::Tensor(kNumberTypeFloat, c_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  out_t->MallocData();
+  outputs_->push_back(out_t);
+
+  return out_t->ElementsNum();
+}
+
 TEST_F(TestPowerFp32, Simple) {
   std::vector<lite::tensor::Tensor *> inputs_;
   std::vector<lite::tensor::Tensor *> outputs_;
@@ -62,13 +78,12 @@
   int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
   auto ctx = new lite::Context;
   ctx->thread_num_ = 1;
-  kernel::PowerCPUKernel *op = new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_,
-                                                          ctx, nullptr);
+  kernel::PowerCPUKernel *op =
+    new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr);
   op->Init();
   op->Run();
   float correct[] = {1, 64, 2187, 65536};
   float *output = reinterpret_cast<float *>(outputs_[0]->Data());
-  for (int i = 0; i < 4; ++i) printf("%f ", output[i]);
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
   delete op;
   for (auto t : inputs_) delete t;
@@ -79,18 +94,17 @@ TEST_F(TestPowerFp32, Broadcast) {
   std::vector<lite::tensor::Tensor *> inputs_;
   std::vector<lite::tensor::Tensor *> outputs_;
   auto param = new PowerParameter();
+  param->power_ = 2;
   param->scale_ = 1;
   param->shift_ = 0;
   float a[] = {1, 2, 3, 4};
-  float b[] = {2};
   std::vector<int> a_shape = {2, 2};
-  std::vector<int> b_shape = {1};
   std::vector<int> c_shape = {2, 2};
-  int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
+  int total_size = PowerTestInit2(&inputs_, &outputs_, a, a_shape, c_shape);
   auto ctx = new lite::Context;
   ctx->thread_num_ = 2;
-  kernel::PowerCPUKernel *op = new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_,
-                                                          ctx, nullptr);
+  kernel::PowerCPUKernel *op =
+    new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr);
   op->Init();
   op->Run();
   float correct[] = {1, 4, 9, 16};
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc
index 9a6341b148..7f039d0367 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc
@@ -38,6 +38,13 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
 
   std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
 
+  const auto &tflite_attr = tfliteOp->builtin_options.AsFullyConnectedOptions();
+  if (tflite_attr == nullptr) {
+    MS_LOG(ERROR) << "get op: " << op->name << " attr failed";
+    return RET_NULL_PTR;
+  }
+  attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
+
   auto weight_index = tfliteOp->inputs[1];
   const auto &weight_tensor = tfliteTensors[weight_index];
   if (weight_tensor == nullptr) {
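
[Closing note: the test diff above retires the Broadcast case's second
exponent tensor; the exponent now arrives through the scalar power_ field of
PowerParameter, and the new PowerTestInit2 helper builds the single-input
tensor pair for it. The parser diff reads the fused activation out of the
TFLite FlatBuffer options and records it on the exported FullConnection
attr, which feeds the activationType plumbing added in ops.fbs and
populate_parameter.cc. For orientation, here is the kind of epilogue such a
fused activation typically becomes inside an FC kernel; this is illustrative
only, not code from this patch, and the act encoding is hypothetical:]

    #include <algorithm>

    // act: 0 = none, 1 = ReLU, 2 = ReLU6 (hypothetical encoding).
    void ApplyFusedActivation(float *out, int n, int act) {
      for (int i = 0; i < n; ++i) {
        if (act >= 1) out[i] = std::max(out[i], 0.0f);  // ReLU clamps below at 0
        if (act == 2) out[i] = std::min(out[i], 6.0f);  // ReLU6 also clamps at 6
      }
    }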