From b3468fab894bdf4ea39f8c7efa34291bf6eeb495 Mon Sep 17 00:00:00 2001 From: yankai Date: Tue, 11 Aug 2020 16:49:43 +0800 Subject: [PATCH] fix mindspore models runtime on_device --- .../src/common/anf_exporter/anf_exporter.cc | 4 +++- .../anf_populater/anf_reshape_populater.cc | 22 +++++++++++++++++-- .../anf_importer/import_from_protobuf.cc | 3 ++- .../anf_importer/import_from_protobuf.h | 3 ++- mindspore/lite/tools/common/node_util.h | 2 +- .../node/weight_format_pass.cc | 5 ++++- 6 files changed, 32 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index a28d40d7d4..d565d16795 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -385,9 +385,11 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vectornodeType = schema::NodeType_Parameter; nodeIdMap[name] = graph->allTensors.size(); fbnode->outputIndex.emplace_back(graph->allTensors.size()); - graph->allTensors.emplace_back(outputTensor); + graph->allTensors.emplace_back(msTensor); i++; } return; diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc index de5007ca5f..6669d9f11c 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc @@ -23,10 +23,28 @@ namespace mindspore::lite { int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) { - auto attr = std::make_unique(); + auto attr = std::make_unique(); + MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); + auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + if (inputNode->isa()) { + auto valueNode = inputNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto val = valueNode->value(); + MS_ASSERT(val != nullptr); + if (val->isa()) { + auto tuple = val->cast(); + MS_ASSERT(tuple != nullptr); + for (size_t i = 0; i < tuple->size(); ++i) { + auto elem = tuple->value()[i]->cast(); + MS_ASSERT(elem != nullptr); + attr->shape.emplace_back(static_cast(elem->value())); + } + } + } + node->nodeType = schema::NodeType_CNode; node->primitive = std::make_unique(); - node->primitive->value.type = schema::PrimitiveType_Flatten; + node->primitive->value.type = schema::PrimitiveType_Reshape; node->primitive->value.value = attr.release(); return 0; } diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc index 904e0bd0ad..28eae27b3a 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc @@ -639,7 +639,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); return true; } -#endif +#else #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ @@ -1108,6 +1108,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); return true; } +#endif bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { MS_EXCEPTION_IF_NULL(outputFuncGraph); diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h index 4513c79f17..24502ffa90 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h @@ -77,7 +77,7 @@ class AnfImporterFromProtobuf : public AnfImporter { const onnx::TensorProto &attr_tensor); std::unordered_map GetAbstractForCNode(const onnx::AttributeProto &attr_proto); -#endif +#else bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); @@ -100,6 +100,7 @@ class AnfImporterFromProtobuf : public AnfImporter { bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); +#endif private: diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 43e9aef558..1f017a54c5 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -232,7 +232,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in if (type == kCKHW2HWCK) { p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); - } else if (type == kKCHW2KHWC) { + } else if (type == kCKHW2KHWC) { p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c)); } else { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index 2e894a24f4..7add49877f 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -350,6 +350,9 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { // todo(00445839): consider varible weight condition } } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW + if (graphNode->subGraph->fmkType == converter::FmkType_MS) { + weightTensor->format = schema::Format_CKHW; + } if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); } else if (weightTensor->format == schema::Format_KCHW) { @@ -362,7 +365,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { } if (status == 0) { node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_CKHW; + weightTensor->format = schema::Format_KHWC; } else { MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); // todo(00445839): consider varible weight condition