!9078 [MS][LITE]Fix sub support int32

From: @gongdaguo
Reviewed-by: @hangangqiang,@zhang_xue_tong
Signed-off-by: @hangangqiang
pull/9078/MERGE
mindspore-ci-bot authored 4 years ago, committed by Gitee
commit 366afe5d9a

@@ -725,6 +725,22 @@ int ElementSub(const float *input0, const float *input1, float *output, const in
   return NNACL_OK;
 }
 
+int ElementSubInt(const int *input0, const int *input1, int *output, const int element_size) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  for (; index <= element_size - 4; index += C4NUM) {
+    int32x4_t vin0 = vld1q_s32(input0 + index);
+    int32x4_t vin1 = vld1q_s32(input1 + index);
+    int32x4_t vout = vsubq_s32(vin0, vin1);
+    vst1q_s32(output + index, vout);
+  }
+#endif
+  for (; index < element_size; index++) {
+    output[index] = input0[index] - input1[index];
+  }
+  return NNACL_OK;
+}
+
 int ElementSubRelu(const float *input0, const float *input1, float *output, const int element_size) {
   int index = 0;
 #ifdef ENABLE_NEON

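For reference, here is a minimal standalone harness (hypothetical, not part of the commit) that exercises the new ElementSubInt exactly as added above. NNACL_OK and C4NUM are stubbed with the values NNACL uses (0 and 4); the NEON loop only compiles when ENABLE_NEON is defined for an ARM target, and the scalar tail loop covers both non-NEON builds and the remainder when element_size is not a multiple of 4.

/* Hypothetical harness for ElementSubInt; compiles as C or C++. */
#include <stdio.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif

#define NNACL_OK 0 /* stubbed: NNACL's success code */
#define C4NUM 4    /* stubbed: four int32 lanes per NEON vector */

int ElementSubInt(const int *input0, const int *input1, int *output, const int element_size) {
  int index = 0;
#ifdef ENABLE_NEON
  /* Vectorized loop: subtract four int32 elements per iteration. */
  for (; index <= element_size - 4; index += C4NUM) {
    int32x4_t vin0 = vld1q_s32(input0 + index);
    int32x4_t vin1 = vld1q_s32(input1 + index);
    int32x4_t vout = vsubq_s32(vin0, vin1);
    vst1q_s32(output + index, vout);
  }
#endif
  /* Scalar tail: remainder elements, or everything on non-NEON builds. */
  for (; index < element_size; index++) {
    output[index] = input0[index] - input1[index];
  }
  return NNACL_OK;
}

int main(void) {
  const int a[6] = {10, 20, 30, 40, 50, 60};
  const int b[6] = {1, 2, 3, 4, 5, 6};
  int out[6];
  ElementSubInt(a, b, out, 6);
  for (int i = 0; i < 6; ++i) {
    printf("%d ", out[i]); /* expected: 9 18 27 36 45 54 */
  }
  printf("\n");
  return 0;
}
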
@@ -77,6 +77,7 @@ int BroadcastAddInt8(const int8_t *input0, const int8_t *input1, int8_t *tile_in
                      int8_t *output, int element_size, ArithmeticParameter *param);
 int ElementSub(const float *input0, const float *input1, float *output, const int element_size);
+int ElementSubInt(const int *input0, const int *input1, int *output, const int element_size);
 int ElementSubRelu(const float *input0, const float *input1, float *output, const int element_size);
 int ElementSubRelu6(const float *input0, const float *input1, float *output, const int element_size);
 int BroadcastSub(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,

@@ -363,6 +363,8 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelC
 REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)

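The two added REG_KERNEL lines mirror the Add entries a few lines up: both the kNumberTypeInt and kNumberTypeInt32 tags are routed to CpuArithmeticFp32KernelCreator, so an int32 Sub now resolves a kernel whichever tag the model carries. Conceptually, this kind of registration fills a lookup table keyed by (device, data type, op type); the sketch below illustrates that pattern with hypothetical names and is not MindSpore Lite's actual registry.

// Hypothetical registry sketch; all names are illustrative, not MindSpore's.
#include <cstdio>

typedef void *(*KernelCreator)();

struct KernelEntry {
  int device;             // stands in for a kCPU-style tag
  int dtype;              // stands in for kNumberTypeInt / kNumberTypeInt32
  int op;                 // stands in for PrimitiveType_Sub
  KernelCreator creator;  // stands in for CpuArithmeticFp32KernelCreator
};

enum { DEV_CPU };
enum { DT_FLOAT32, DT_INT, DT_INT32 };
enum { OP_ADD, OP_SUB };

static void *DummyCreator() { return nullptr; }

// One row per REG_KERNEL-style registration; note both int tags for OP_SUB,
// matching the two lines added in the diff.
static const KernelEntry g_registry[] = {
  {DEV_CPU, DT_FLOAT32, OP_SUB, DummyCreator},
  {DEV_CPU, DT_INT, OP_SUB, DummyCreator},
  {DEV_CPU, DT_INT32, OP_SUB, DummyCreator},
};

static KernelCreator FindCreator(int device, int dtype, int op) {
  for (const auto &e : g_registry) {
    if (e.device == device && e.dtype == dtype && e.op == op) return e.creator;
  }
  return nullptr;  // before the fix, an int32 Sub lookup fell through here
}

int main() {
  std::printf("int32 Sub kernel found: %s\n",
              FindCreator(DEV_CPU, DT_INT32, OP_SUB) ? "yes" : "no");
  return 0;
}
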
@@ -97,6 +97,7 @@ class ArithmeticCPUKernel : public LiteKernel {
         break;
       default:
         arithmetic_run_ = ElementSub;
+        arithmetic_run_int_ = ElementSubInt;
         break;
     }
     break;

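With arithmetic_run_int_ populated in the default branch above, the kernel holds both a float and an int function pointer for Sub. The surrounding Run() logic is not shown in this diff; the sketch below is a guess at the dispatch shape, keeping only the two member names from the commit.

// Hypothetical dispatch sketch; only the two pointer names come from the diff.
#include <cstdio>

typedef int (*FloatRun)(const float *, const float *, float *, int);
typedef int (*IntRun)(const int *, const int *, int *, int);

enum DataType { DT_FLOAT32, DT_INT32 };

static int ElementSubF(const float *a, const float *b, float *out, int n) {
  for (int i = 0; i < n; ++i) out[i] = a[i] - b[i];
  return 0;
}
static int ElementSubI(const int *a, const int *b, int *out, int n) {
  for (int i = 0; i < n; ++i) out[i] = a[i] - b[i];
  return 0;
}

// Mirrors the kernel members assigned in the diff's default branch.
static FloatRun arithmetic_run_ = ElementSubF;
static IntRun arithmetic_run_int_ = ElementSubI;

// Assumed dispatch: pick the pointer that matches the tensors' data type.
static int Run(DataType dtype, const void *in0, const void *in1, void *out, int n) {
  if (dtype == DT_INT32) {
    return arithmetic_run_int_(static_cast<const int *>(in0), static_cast<const int *>(in1),
                               static_cast<int *>(out), n);
  }
  return arithmetic_run_(static_cast<const float *>(in0), static_cast<const float *>(in1),
                         static_cast<float *>(out), n);
}

int main() {
  int a[3] = {7, 8, 9}, b[3] = {1, 2, 3}, out[3];
  Run(DT_INT32, a, b, out, 3);
  std::printf("%d %d %d\n", out[0], out[1], out[2]);  // 6 7 8
  return 0;
}
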
@@ -285,9 +285,9 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe
     }
   }
   // update nodes indexes
-  for (auto nodeIter = graphT->nodes.begin(); nodeIter != graphT->nodes.end(); nodeIter++) {
+  for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
     // update nodes input indexes
-    UpdateNodeIndex((*nodeIter).get(), deleteIdx);
+    UpdateNodeIndex((*node_iter).get(), deleteIdx);
   }
   // update deleteTensorIdx
   for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
@@ -374,10 +374,10 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPla
     MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
     return graphT->nodes.end();
   }
-  auto nodeIter = graphT->nodes.begin() + existNodeIdx;
-  MS_ASSERT(nodeIter != graphT->nodes.begin());
-  MS_ASSERT((*nodeIter) != nullptr);
-  return InsertNode(graphT, nodeIter, place, inoutIndex, std::move(toAddNode), errorCode);
+  auto node_iter = graphT->nodes.begin() + existNodeIdx;
+  MS_ASSERT(node_iter != graphT->nodes.begin());
+  MS_ASSERT((*node_iter) != nullptr);
+  return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode);
 }
 
 NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,

@@ -131,33 +131,33 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   const auto &onnx_conv_weight = onnx_node.input(1);
   if (onnx_node.op_type() == "Conv") {
-    auto nodeIter =
+    auto node_iter =
       std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                    [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
-    if (nodeIter == onnx_graph.initializer().end()) {
+    if (node_iter == onnx_graph.initializer().end()) {
       MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
     } else {
       std::vector<int> weight_shape;
-      auto size = (*nodeIter).dims_size();
+      auto size = (*node_iter).dims_size();
       weight_shape.reserve(size);
       for (int i = 0; i < size; ++i) {
-        weight_shape.emplace_back((*nodeIter).dims(i));
+        weight_shape.emplace_back((*node_iter).dims(i));
       }
       attr->channelOut = weight_shape[0];
       attr->channelIn = weight_shape[1] * attr->group;
     }
   } else {
-    auto nodeIter =
+    auto node_iter =
       std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
                    [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
-    if (nodeIter == onnx_graph.node().end()) {
+    if (node_iter == onnx_graph.node().end()) {
       MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
       return RET_ERROR;
     }
     std::vector<int> dims;
-    auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(),
+    auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(),
                              [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
-    if (iter != (*nodeIter).attribute().end()) {
+    if (iter != (*node_iter).attribute().end()) {
       if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) {
         MS_LOG(ERROR) << "dims insert failed";
         return RET_ERROR;

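Beyond the rename, the conv hunk above shows the parser's pattern for recovering a weight shape: find the initializer (or the producing node) whose name matches the op's second input, then read its dims to derive channelOut and channelIn. Here is a reduced sketch of that lookup, with a stub type standing in for onnx::TensorProto and made-up data:

// Reduced sketch of the initializer lookup; TensorStub and its data are hypothetical.
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

struct TensorStub {
  std::string name;
  std::vector<int> dims;
};

int main() {
  std::vector<TensorStub> initializer = {{"conv1.weight", {64, 3, 7, 7}}};
  std::string onnx_conv_weight = "conv1.weight";  // the op's input(1) in the parser
  auto node_iter = std::find_if(initializer.begin(), initializer.end(),
                                [&](const TensorStub &t) { return t.name == onnx_conv_weight; });
  if (node_iter == initializer.end()) {
    std::printf("not find node: %s\n", onnx_conv_weight.c_str());
    return 1;
  }
  int group = 1;  // attr->group in the parser
  // As in the parser: dims[0] -> channelOut, dims[1] * group -> channelIn.
  std::printf("channelOut=%d channelIn=%d\n", node_iter->dims[0], node_iter->dims[1] * group);
  return 0;
}
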
@@ -133,18 +133,18 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
   }
 
   const auto &onnx_conv_weight = onnx_node.input(1);
-  auto nodeIter =
+  auto node_iter =
     std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                  [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
-  if (nodeIter == onnx_graph.initializer().end()) {
+  if (node_iter == onnx_graph.initializer().end()) {
     MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
     return RET_ERROR;
   }
   std::vector<int> weight_shape;
-  auto size = (*nodeIter).dims_size();
+  auto size = (*node_iter).dims_size();
   weight_shape.reserve(size);
   for (int i = 0; i < size; ++i) {
-    weight_shape.emplace_back((*nodeIter).dims(i));
+    weight_shape.emplace_back((*node_iter).dims(i));
   }
   if (weight_shape.size() != 4) {
     MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size();

@@ -41,14 +41,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
   std::vector<int> dst_shape;
   const auto &onnx_expand_power = onnx_node.input(1);
-  auto nodeIter =
+  auto node_iter =
     std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
                  [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; });
-  if (nodeIter == onnx_graph.node().end()) {
+  if (node_iter == onnx_graph.node().end()) {
     MS_LOG(ERROR) << "can not find node: " << onnx_expand_power;
     return RET_ERROR;
   }
-  for (const auto &attrPower : nodeIter->attribute()) {
+  for (const auto &attrPower : node_iter->attribute()) {
     if (attrPower.name() == "value") {
       const auto &t = attrPower.t();
       auto *dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data());
