diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c index acfbf32106..37d7b9d303 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c +++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c @@ -725,6 +725,22 @@ int ElementSub(const float *input0, const float *input1, float *output, const in return NNACL_OK; } +int ElementSubInt(const int *input0, const int *input1, int *output, const int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 4; index += C4NUM) { + int32x4_t vin0 = vld1q_s32(input0 + index); + int32x4_t vin1 = vld1q_s32(input1 + index); + int32x4_t vout = vsubq_s32(vin0, vin1); + vst1q_s32(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] - input1[index]; + } + return NNACL_OK; +} + int ElementSubRelu(const float *input0, const float *input1, float *output, const int element_size) { int index = 0; #ifdef ENABLE_NEON diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h index 0ca64c2a62..f7030e2ee5 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h +++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h @@ -77,6 +77,7 @@ int BroadcastAddInt8(const int8_t *input0, const int8_t *input1, int8_t *tile_in int8_t *output, int element_size, ArithmeticParameter *param); int ElementSub(const float *input0, const float *input1, float *output, const int element_size); +int ElementSubInt(const int *input0, const int *input1, int *output, const int element_size); int ElementSubRelu(const float *input0, const float *input1, float *output, const int element_size); int ElementSubRelu6(const float *input0, const float *input1, float *output, const int element_size); int BroadcastSub(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 5587725093..b802488a0e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -363,6 +363,8 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelC REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index b20a06043d..3908058fcf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -97,6 +97,7 @@ class ArithmeticCPUKernel : public LiteKernel { break; default: arithmetic_run_ = ElementSub; + arithmetic_run_int_ = ElementSubInt; break; } break; diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index dc16534641..30632ee916 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -285,9 +285,9 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector toDeleteTe } } // update nodes indexes - for (auto nodeIter = graphT->nodes.begin(); nodeIter != graphT->nodes.end(); nodeIter++) { + for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) { // update nodes input indexes - UpdateNodeIndex((*nodeIter).get(), deleteIdx); + UpdateNodeIndex((*node_iter).get(), deleteIdx); } // update deleteTensorIdx for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) { @@ -374,10 +374,10 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPla MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx; return graphT->nodes.end(); } - auto nodeIter = graphT->nodes.begin() + existNodeIdx; - MS_ASSERT(nodeIter != graphT->nodes.begin()); - MS_ASSERT((*nodeIter) != nullptr); - return InsertNode(graphT, nodeIter, place, inoutIndex, std::move(toAddNode), errorCode); + auto node_iter = graphT->nodes.begin() + existNodeIdx; + MS_ASSERT(node_iter != graphT->nodes.begin()); + MS_ASSERT((*node_iter) != nullptr); + return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode); } NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index 3bf5013748..b4944a93d1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -131,33 +131,33 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod const auto &onnx_conv_weight = onnx_node.input(1); if (onnx_node.op_type() == "Conv") { - auto nodeIter = + auto node_iter = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); - if (nodeIter == onnx_graph.initializer().end()) { + if (node_iter == onnx_graph.initializer().end()) { MS_LOG(WARNING) << "not find node: " << onnx_conv_weight; } else { std::vector weight_shape; - auto size = (*nodeIter).dims_size(); + auto size = (*node_iter).dims_size(); weight_shape.reserve(size); for (int i = 0; i < size; ++i) { - weight_shape.emplace_back((*nodeIter).dims(i)); + weight_shape.emplace_back((*node_iter).dims(i)); } attr->channelOut = weight_shape[0]; attr->channelIn = weight_shape[1] * attr->group; } } else { - auto nodeIter = + auto node_iter = std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); - if (nodeIter == onnx_graph.node().end()) { + if (node_iter == onnx_graph.node().end()) { MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight; return RET_ERROR; } std::vector dims; - auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), + auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(), [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); - if (iter != (*nodeIter).attribute().end()) { + if (iter != (*node_iter).attribute().end()) { if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { MS_LOG(ERROR) << "dims insert failed"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc index b966a8fd5d..268129438e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc @@ -133,18 +133,18 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } const auto &onnx_conv_weight = onnx_node.input(1); - auto nodeIter = + auto node_iter = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); - if (nodeIter == onnx_graph.initializer().end()) { + if (node_iter == onnx_graph.initializer().end()) { MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str(); return RET_ERROR; } std::vector weight_shape; - auto size = (*nodeIter).dims_size(); + auto size = (*node_iter).dims_size(); weight_shape.reserve(size); for (int i = 0; i < size; ++i) { - weight_shape.emplace_back((*nodeIter).dims(i)); + weight_shape.emplace_back((*node_iter).dims(i)); } if (weight_shape.size() != 4) { MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc index 46fb4288ec..87414ae823 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -41,14 +41,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N std::vector dst_shape; const auto &onnx_expand_power = onnx_node.input(1); - auto nodeIter = + auto node_iter = std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; }); - if (nodeIter == onnx_graph.node().end()) { + if (node_iter == onnx_graph.node().end()) { MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; return RET_ERROR; } - for (const auto &attrPower : nodeIter->attribute()) { + for (const auto &attrPower : node_iter->attribute()) { if (attrPower.name() == "value") { const auto &t = attrPower.t(); auto *dataPtr = reinterpret_cast(t.raw_data().data());