From 815b7af9ecc6e1687a8627f074c67f703eb7891f Mon Sep 17 00:00:00 2001 From: gongdaguo Date: Wed, 25 Nov 2020 22:06:39 +0800 Subject: [PATCH] fix arithmetic compare, matmul, logicalnot, constant_folding_fusion --- mindspore/lite/src/ops/arithmetic_compare.cc | 33 +++++ mindspore/lite/src/ops/arithmetic_compare.h | 41 ++++++ mindspore/lite/src/ops/equal.cc | 10 -- mindspore/lite/src/ops/equal.h | 9 +- mindspore/lite/src/ops/greater.cc | 10 -- mindspore/lite/src/ops/greater.h | 9 +- mindspore/lite/src/ops/greater_equal.cc | 10 -- mindspore/lite/src/ops/greater_equal.h | 9 +- mindspore/lite/src/ops/less.cc | 10 -- mindspore/lite/src/ops/less.h | 9 +- mindspore/lite/src/ops/less_equal.cc | 10 -- mindspore/lite/src/ops/less_equal.h | 9 +- mindspore/lite/src/ops/matmul.cc | 20 ++- mindspore/lite/src/ops/not_equal.cc | 10 -- mindspore/lite/src/ops/not_equal.h | 9 +- .../src/ops/populate/arithmetic_populate.cc | 4 +- .../arm/fp32/arithmetic_compare_fp32.cc | 123 +++++++++--------- .../kernel/arm/fp32/arithmetic_compare_fp32.h | 44 ++++++- .../kernel/arm/fp32/arithmetic_fp32.cc | 9 ++ .../runtime/kernel/arm/fp32/arithmetic_fp32.h | 10 +- .../kernel/arm/int8/arithmetic_self_int8.cc | 1 + .../converter/parser/onnx/onnx_conv_parser.cc | 20 +-- .../fusion/constant_folding_fusion.cc | 9 -- 23 files changed, 244 insertions(+), 184 deletions(-) create mode 100644 mindspore/lite/src/ops/arithmetic_compare.cc create mode 100644 mindspore/lite/src/ops/arithmetic_compare.h diff --git a/mindspore/lite/src/ops/arithmetic_compare.cc b/mindspore/lite/src/ops/arithmetic_compare.cc new file mode 100644 index 0000000000..661883b98a --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic_compare.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/arithmetic_compare.h" + +namespace mindspore { +namespace lite { + +int ArithmeticCompare::InferShape(std::vector inputs_, std::vector outputs_) { + auto res = Arithmetic::InferShape(inputs_, outputs_); + if (res == RET_OK) { + auto output = outputs_.front(); + output->set_data_type(TypeId::kNumberTypeBool); + return RET_OK; + } else { + return res; + } +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/arithmetic_compare.h b/mindspore/lite/src/ops/arithmetic_compare.h new file mode 100644 index 0000000000..4917a61792 --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic_compare.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_ + +#include +#include +#include + +#include "src/ops/arithmetic.h" + +namespace mindspore { +namespace lite { +class ArithmeticCompare : public Arithmetic { + public: + ArithmeticCompare() = default; + ~ArithmeticCompare() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(ArithmeticCompare, Arithmetic); + explicit ArithmeticCompare(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_ diff --git a/mindspore/lite/src/ops/equal.cc b/mindspore/lite/src/ops/equal.cc index dd55c8f265..29d91ec01a 100644 --- a/mindspore/lite/src/ops/equal.cc +++ b/mindspore/lite/src/ops/equal.cc @@ -35,16 +35,6 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator); #endif -int Equal::InferShape(std::vector inputs_, std::vector outputs_) { - auto input = inputs_.front(); - MS_ASSERT(input != nullptr); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeBool); - output->set_format(input->format()); - return RET_OK; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/equal.h b/mindspore/lite/src/ops/equal.h index 9d88d92853..4eb0efe8ff 100644 --- a/mindspore/lite/src/ops/equal.h +++ b/mindspore/lite/src/ops/equal.h @@ -20,21 +20,20 @@ #include #include #include -#include "src/ops/arithmetic.h" +#include "src/ops/arithmetic_compare.h" namespace mindspore { namespace lite { -class Equal : public Arithmetic { +class Equal : public ArithmeticCompare { public: Equal() = default; ~Equal() = default; #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(Equal, PrimitiveC); - explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + MS_DECLARE_PARENT(Equal, ArithmeticCompare); + explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif - int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater.cc b/mindspore/lite/src/ops/greater.cc index f6df714250..2950e175a0 100644 --- a/mindspore/lite/src/ops/greater.cc +++ b/mindspore/lite/src/ops/greater.cc @@ -36,16 +36,6 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator); #endif -int Greater::InferShape(std::vector inputs_, std::vector outputs_) { - auto input = inputs_.front(); - MS_ASSERT(input != nullptr); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeBool); - output->set_format(input->format()); - return RET_OK; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater.h b/mindspore/lite/src/ops/greater.h index 61a7773f03..c7de708ec2 100644 --- a/mindspore/lite/src/ops/greater.h +++ b/mindspore/lite/src/ops/greater.h @@ -20,21 +20,20 @@ #include #include -#include "src/ops/arithmetic.h" +#include "src/ops/arithmetic_compare.h" namespace mindspore { namespace lite { -class Greater : public Arithmetic { +class Greater : public ArithmeticCompare { public: Greater() = default; ~Greater() = default; #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(Greater, Arithmetic); - explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + MS_DECLARE_PARENT(Greater, ArithmeticCompare); + explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif - int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.cc b/mindspore/lite/src/ops/greater_equal.cc index 7cbf55237d..e7dd799802 100644 --- a/mindspore/lite/src/ops/greater_equal.cc +++ b/mindspore/lite/src/ops/greater_equal.cc @@ -38,16 +38,6 @@ PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) { Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator); #endif -int GreaterEqual::InferShape(std::vector inputs_, std::vector outputs_) { - auto input = inputs_.front(); - MS_ASSERT(input != nullptr); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeBool); - output->set_format(input->format()); - return RET_OK; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.h b/mindspore/lite/src/ops/greater_equal.h index 1bd8f5bcc4..f8df62e2fa 100644 --- a/mindspore/lite/src/ops/greater_equal.h +++ b/mindspore/lite/src/ops/greater_equal.h @@ -21,21 +21,20 @@ #include #include -#include "src/ops/arithmetic.h" +#include "src/ops/arithmetic_compare.h" namespace mindspore { namespace lite { -class GreaterEqual : public Arithmetic { +class GreaterEqual : public ArithmeticCompare { public: GreaterEqual() = default; ~GreaterEqual() = default; #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(GreaterEqual, Arithmetic); - explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + MS_DECLARE_PARENT(GreaterEqual, ArithmeticCompare); + explicit GreaterEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif - int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less.cc b/mindspore/lite/src/ops/less.cc index 25650a879a..fe4d82ee76 100644 --- a/mindspore/lite/src/ops/less.cc +++ b/mindspore/lite/src/ops/less.cc @@ -38,16 +38,6 @@ PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC: Registry LessRegistry(schema::PrimitiveType_Less, LessCreator); #endif -int Less::InferShape(std::vector inputs_, std::vector outputs_) { - auto input = inputs_.front(); - MS_ASSERT(input != nullptr); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeBool); - output->set_format(input->format()); - return RET_OK; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less.h b/mindspore/lite/src/ops/less.h index 09eeeccebb..2967cfd70a 100644 --- a/mindspore/lite/src/ops/less.h +++ b/mindspore/lite/src/ops/less.h @@ -21,21 +21,20 @@ #include #include -#include "src/ops/arithmetic.h" +#include "src/ops/arithmetic_compare.h" namespace mindspore { namespace lite { -class Less : public Arithmetic { +class Less : public ArithmeticCompare { public: Less() = default; ~Less() = default; #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(Less, Arithmetic); - explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + MS_DECLARE_PARENT(Less, ArithmeticCompare); + explicit Less(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif - int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.cc b/mindspore/lite/src/ops/less_equal.cc index 7ecadad839..89a88fc6c7 100644 --- a/mindspore/lite/src/ops/less_equal.cc +++ b/mindspore/lite/src/ops/less_equal.cc @@ -37,16 +37,6 @@ PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) { } Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator); #endif -int LessEqual::InferShape(std::vector inputs_, std::vector outputs_) { - auto input = inputs_.front(); - MS_ASSERT(input != nullptr); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeBool); - output->set_format(input->format()); - return RET_OK; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.h b/mindspore/lite/src/ops/less_equal.h index 20589b4eca..ade4d12c4c 100644 --- a/mindspore/lite/src/ops/less_equal.h +++ b/mindspore/lite/src/ops/less_equal.h @@ -21,21 +21,20 @@ #include #include -#include "src/ops/arithmetic.h" +#include "src/ops/arithmetic_compare.h" namespace mindspore { namespace lite { -class LessEqual : public Arithmetic { +class LessEqual : public ArithmeticCompare { public: LessEqual() = default; ~LessEqual() = default; #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(LessEqual, Arithmetic); - explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + MS_DECLARE_PARENT(LessEqual, ArithmeticCompare); + explicit LessEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif - int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index 3ffbeaa654..f6d9eb3d6a 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -112,9 +112,17 @@ int MatMul::InferShape(std::vector inputs_, std::vector outp input0->set_shape(a_shape); } - if (a_shape.size() < 2 || b_shape.size() < 2) { - MS_LOG(ERROR) << "inputs shape is invalid"; - return RET_INPUT_TENSOR_ERROR; + bool del_start = false; + bool del_end = false; + if (a_shape.size() == 1) { + a_shape.insert(a_shape.begin(), 1); + input0->set_shape(a_shape); + del_start = true; + } + if (b_shape.size() == 1) { + b_shape.push_back(1); + input1->set_shape(b_shape); + del_end = true; } for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) { if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) { @@ -131,6 +139,12 @@ int MatMul::InferShape(std::vector inputs_, std::vector outp } std::vector c_shape(a_shape); c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1]; + if (del_start) { + c_shape.erase(c_shape.begin()); + } + if (del_end) { + c_shape.pop_back(); + } output->set_shape(c_shape); return RET_OK; } diff --git a/mindspore/lite/src/ops/not_equal.cc b/mindspore/lite/src/ops/not_equal.cc index 2a7b5fb0fe..618025c400 100644 --- a/mindspore/lite/src/ops/not_equal.cc +++ b/mindspore/lite/src/ops/not_equal.cc @@ -38,16 +38,6 @@ PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) { Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator); #endif -int NotEqual::InferShape(std::vector inputs_, std::vector outputs_) { - auto input = inputs_.front(); - MS_ASSERT(input != nullptr); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeBool); - output->set_format(input->format()); - return RET_OK; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/not_equal.h b/mindspore/lite/src/ops/not_equal.h index d6c112e4e5..464d27d685 100644 --- a/mindspore/lite/src/ops/not_equal.h +++ b/mindspore/lite/src/ops/not_equal.h @@ -21,21 +21,20 @@ #include #include -#include "src/ops/arithmetic.h" +#include "src/ops/arithmetic_compare.h" namespace mindspore { namespace lite { -class NotEqual : public Arithmetic { +class NotEqual : public ArithmeticCompare { public: NotEqual() = default; ~NotEqual() = default; #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(NotEqual, Arithmetic); - explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + MS_DECLARE_PARENT(NotEqual, ArithmeticCompare); + explicit NotEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif - int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/arithmetic_populate.cc b/mindspore/lite/src/ops/populate/arithmetic_populate.cc index 196959ef73..74dffde9d7 100644 --- a/mindspore/lite/src/ops/populate/arithmetic_populate.cc +++ b/mindspore/lite/src/ops/populate/arithmetic_populate.cc @@ -58,11 +58,11 @@ Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithme Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic); Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); +Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); +Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic); -Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); -Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc index 7d7b032b9f..62df9fabf1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc @@ -28,77 +28,82 @@ using mindspore::schema::PrimitiveType_LessEqual; using mindspore::schema::PrimitiveType_NotEqual; namespace mindspore::kernel { -namespace { -typedef struct { - int primitive_type_; - ArithmeticCompareFp32Func func_; -} TYPE_FUNC_INFO; -} // namespace -ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) { - TYPE_FUNC_INFO type_func_table[] = { - {PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32}, - {PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32}, - {PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}}; - for (size_t i = 0; i < sizeof(type_func_table); i++) { - if (type_func_table[i].primitive_type_ == primitive_type) { - return type_func_table[i].func_; +int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, + int out_thread_stride) { + if (dim > break_pos_) { + if (data_type_ == kDataTypeInt) { + return func_int32_(reinterpret_cast(input0) + out_thread_stride, + reinterpret_cast(input1) + out_thread_stride, + reinterpret_cast(output) + out_thread_stride, out_count); } + return func_fp32_(reinterpret_cast(input0) + out_thread_stride, + reinterpret_cast(input1) + out_thread_stride, + reinterpret_cast(output) + out_thread_stride, out_count); } - return nullptr; -} - -int ArithmeticCompareCPUKernel::Init() { - if (!InferShapeDone()) { - return RET_OK; + for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) { + int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i; + int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i; + int error_code; + if (data_type_ == kDataTypeInt) { + error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], + reinterpret_cast(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], + reinterpret_cast(output) + i * arithmeticParameter_->out_strides_[dim], + dim + 1, out_count, out_thread_stride); + } else { + error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], + reinterpret_cast(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], + reinterpret_cast(output) + i * arithmeticParameter_->out_strides_[dim], + dim + 1, out_count, out_thread_stride); + } + if (error_code != RET_OK) { + return error_code; + } } - return ReSize(); + return RET_OK; } -int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; } +int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { + auto element_num = out_tensors_[0]->ElementsNum(); -int ArithmeticCompareCPUKernel::DoExecute(int task_id) { - if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) { - MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! "; - return RET_ERROR; - } - int elements_num = in_tensors_.at(0)->ElementsNum(); - int stride = UP_DIV(elements_num, op_parameter_->thread_num_); - int offset = task_id * stride; - int count = MSMIN(stride, elements_num - offset); - if (count <= 0) { - return RET_OK; - } - if (func_ == nullptr) { - MS_LOG(ERROR) << "Run function is null! "; + MS_ASSERT(thread_count_ != 0); + int stride = UP_DIV(element_num, thread_count_); + int count = MSMIN(stride, element_num - stride * task_id); + + if (func_fp32_ == nullptr) { + MS_LOG(ERROR) << "func_fp32_ function is nullptr!"; return RET_ERROR; } - // two inputs have the same shape, support broadcast later - auto *input0_ptr = reinterpret_cast(in_tensors_.at(0)->MutableData()); - auto *input1_ptr = reinterpret_cast(in_tensors_.at(1)->MutableData()); - auto *output_ptr = reinterpret_cast(out_tensors_.at(0)->MutableData()); - auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Run failed, illegal input! "; - } - return ret; -} -int ArithmeticCompareRun(void *cdata, int task_id) { - auto kernel = reinterpret_cast(cdata); - auto ret = kernel->DoExecute(task_id); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; + int error_code; + if (arithmeticParameter_->broadcasting_) { // need broadcast + stride = UP_DIV(outside_, thread_count_); + int out_count = MSMIN(stride, outside_ - stride * task_id); + int out_thread_stride = stride * task_id; + if (data_type_ == kDataTypeFloat) { + error_code = BroadcastRun( + reinterpret_cast(in_tensors_[0]->data_c()), reinterpret_cast(in_tensors_[1]->data_c()), + reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); + } else { + error_code = BroadcastRun( + reinterpret_cast(in_tensors_[0]->data_c()), reinterpret_cast(in_tensors_[1]->data_c()), + reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); + } + } else { // no broadcast, neither is scalar, two same shape + if (data_type_ == kDataTypeFloat) { + error_code = func_fp32_(reinterpret_cast(in_tensors_[0]->data_c()) + stride * task_id, + reinterpret_cast(in_tensors_[1]->data_c()) + stride * task_id, + reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); + } else { + error_code = func_int32_(reinterpret_cast(in_tensors_[0]->data_c()) + stride * task_id, + reinterpret_cast(in_tensors_[1]->data_c()) + stride * task_id, + reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); + } } - return ret; -} - -int ArithmeticCompareCPUKernel::Run() { - auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; + if (error_code != RET_OK) { + return RET_ERROR; } - return ret; + return RET_OK; } kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.h index 42bb61237e..fad5565612 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.h @@ -18,27 +18,57 @@ #include #include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/arithmetic_compare_fp32.h" namespace mindspore::kernel { typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size); +typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size); class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel { public: explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) { - func_ = GetArithmeticCompareFun(parameter->type_); + switch (parameter->type_) { + case PrimitiveType_Equal: + func_fp32_ = ElementEqualFp32; + func_int32_ = ElementEqualInt32; + break; + case PrimitiveType_NotEqual: + func_fp32_ = ElementNotEqualFp32; + func_int32_ = ElementNotEqualInt32; + break; + case PrimitiveType_Less: + func_fp32_ = ElementLessFp32; + func_int32_ = ElementLessInt32; + break; + case PrimitiveType_LessEqual: + func_fp32_ = ElementLessEqualFp32; + func_int32_ = ElementLessEqualInt32; + break; + case PrimitiveType_Greater: + func_fp32_ = ElementGreaterFp32; + func_int32_ = ElementGreaterInt32; + break; + case PrimitiveType_GreaterEqual: + func_fp32_ = ElementGreaterEqualFp32; + func_int32_ = ElementGreaterEqualInt32; + break; + default: + MS_LOG(ERROR) << "Error Operator type " << parameter->type_; + func_fp32_ = nullptr; + func_int32_ = nullptr; + break; + } } ~ArithmeticCompareCPUKernel() override = default; - int Init() override; - int ReSize() override; - int Run() override; - virtual int DoExecute(int task_id); + int DoArithmetic(int task_id) override; + int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override; private: - ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type); - ArithmeticCompareFp32Func func_; + ArithmeticCompareFp32Func func_fp32_ = nullptr; + ArithmeticCompareIntFunc func_int32_ = nullptr; }; int ArithmeticCompareRun(void *cdata, int task_id); } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index c9a0d7eac9..5587725093 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -175,6 +175,15 @@ int ArithmeticCPUKernel::ReSize() { break; } break; + case PrimitiveType_Equal: + case PrimitiveType_Less: + case PrimitiveType_Greater: + case PrimitiveType_NotEqual: + case PrimitiveType_LessEqual: + case PrimitiveType_GreaterEqual: + arithmetic_opt_run_ = nullptr; + arithmetic_opt_run_int_ = nullptr; + break; default: break; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index 552118fef2..b20a06043d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -167,19 +167,21 @@ class ArithmeticCPUKernel : public LiteKernel { int PreProcess() override; int ReSize() override; int Run() override; - int DoArithmetic(int task_id); + virtual int DoArithmetic(int task_id); + virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride); - private: - int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride); + protected: int break_pos_ = 0; int outside_ = 0; int thread_count_ = 1; ArithmeticParameter *arithmeticParameter_ = nullptr; + LiteDataType data_type_ = kDataTypeFloat; + + private: ArithmeticRun arithmetic_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr; ArithmeticIntRun arithmetic_run_int_ = nullptr; ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr; - LiteDataType data_type_ = kDataTypeFloat; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc index bdfc349ec4..94e9f4f9c5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc @@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index 3c1cbde2e8..3bf5013748 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -135,17 +135,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); if (nodeIter == onnx_graph.initializer().end()) { - MS_LOG(ERROR) << "not find node: " << onnx_conv_weight; - return RET_ERROR; - } - std::vector weight_shape; - auto size = (*nodeIter).dims_size(); - weight_shape.reserve(size); - for (int i = 0; i < size; ++i) { - weight_shape.emplace_back((*nodeIter).dims(i)); + MS_LOG(WARNING) << "not find node: " << onnx_conv_weight; + } else { + std::vector weight_shape; + auto size = (*nodeIter).dims_size(); + weight_shape.reserve(size); + for (int i = 0; i < size; ++i) { + weight_shape.emplace_back((*nodeIter).dims(i)); + } + attr->channelOut = weight_shape[0]; + attr->channelIn = weight_shape[1] * attr->group; } - attr->channelOut = weight_shape[0]; - attr->channelIn = weight_shape[1] * attr->group; } else { auto nodeIter = std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 566ddb0d93..2bac86f48b 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -231,15 +231,6 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An output_tensors[m]->AddQuantParam(quant_arg); } } - // here, input_tensor's format need to be transposed nhwc according to fmkType, - // but for the time being, we only transpose the tensor with 0/1/2/3D. - // Others should be added in future. - for (auto &input_tensor : input_tensors) { - input_tensor->set_format(schema::Format::Format_NHWC); - if (input_tensor->shape().size() == 4) { - MS_LOG(INFO) << "init input_tensor format to nhwc"; - } - } lite_primitive->InferShape(input_tensors, output_tensors); auto primitive = lite_primitive.get(); MS_ASSERT(primitive != nullptr);