|
|
|
@ -81,8 +81,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
ValueNodePtr valueNode = utils::cast<ValueNodePtr>(indexNode);
|
|
|
|
|
mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] =
|
|
|
|
|
GetValue<int>(valueNode->value());
|
|
|
|
|
mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = GetValue<int>(valueNode->value());
|
|
|
|
|
} else {
|
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
|
}
|
|
|
|
@ -114,17 +113,35 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
|
|
|
|
auto metaGraphT = std::make_unique<schema::MetaGraphT>();
|
|
|
|
|
for (const auto &cnode : cnodes) {
|
|
|
|
|
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
if (primitive != nullptr &&
|
|
|
|
|
RemoveNodeInAnfExporter.count(primitive->name()) != 0) {
|
|
|
|
|
if (primitive != nullptr) {
|
|
|
|
|
if (RemoveNodeInAnfExporter.count(primitive->name()) != 0) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
|
|
|
|
|
auto primT = primitiveT_value->GetPrimitiveT();
|
|
|
|
|
if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
|
|
|
|
|
primT->value.type == schema::PrimitiveType_MakeTuple) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
mapRemoveGetItem_.clear();
|
|
|
|
|
RemoveIfMakeTuple(cnode);
|
|
|
|
|
RemoveIfTupleGetItem(cnode);
|
|
|
|
|
if (primitive != nullptr && primitive->name() == prim::kPrimReturn->name()) {
|
|
|
|
|
|
|
|
|
|
if (primitive != nullptr) {
|
|
|
|
|
if (primitive->name() == prim::kPrimReturn->name()) {
|
|
|
|
|
AddOutPutIfReturn(metaGraphT, cnode);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
|
|
|
|
|
auto primT = primitiveT_value->GetPrimitiveT();
|
|
|
|
|
if (primT->value.type == schema::PrimitiveType_Return) {
|
|
|
|
|
AddOutPutIfReturn(metaGraphT, cnode);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto node = std::make_unique<schema::CNodeT>();
|
|
|
|
|
node->name = cnode->fullname_with_scope();
|
|
|
|
@ -134,27 +151,24 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
|
|
|
|
primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
MS_ASSERT(primitive != nullptr);
|
|
|
|
|
std::string opType = primitive->name();
|
|
|
|
|
auto nodeParser =
|
|
|
|
|
AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
|
|
|
|
|
auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
|
|
|
|
|
if (nodeParser == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::vector<schema::TensorT *> outputs;
|
|
|
|
|
if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
|
|
|
|
|
auto abstract_cnode =
|
|
|
|
|
utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
|
|
|
|
|
auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
|
|
|
|
|
outputs.resize(abstract_cnode->size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
nodeParser->Parse(cnode, node.get(), &outputs);
|
|
|
|
|
SetOpInputNode(cnode, metaGraphT.get(), node.get());
|
|
|
|
|
SetOpOutputNode(outputs, metaGraphT.get(), node.get());
|
|
|
|
|
SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
|
|
|
|
|
metaGraphT->nodes.emplace_back(std::move(node));
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto primitiveT_value =
|
|
|
|
|
GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
|
|
|
|
|
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
|
|
|
|
|
if (primitiveT_value == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
|
|
|
|
|
return nullptr;
|
|
|
|
@ -166,11 +180,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
node->primitive =
|
|
|
|
|
std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
|
|
|
|
|
node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
|
|
|
|
|
std::vector<schema::TensorT *> outputs;
|
|
|
|
|
SetOpInputNode(cnode, metaGraphT.get(), node.get());
|
|
|
|
|
SetOpOutputNode(outputs, metaGraphT.get(), node.get());
|
|
|
|
|
SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
|
|
|
|
|
|
|
|
|
|
// add quant param
|
|
|
|
|
node->quantType = primitiveT_value->GetQuantType();
|
|
|
|
@ -244,9 +257,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
|
|
|
|
return metaGraphT.release();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
|
|
|
|
|
schema::MetaGraphT *meta_graph,
|
|
|
|
|
schema::CNodeT *fbNode) {
|
|
|
|
|
void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) {
|
|
|
|
|
MS_ASSERT(nullptr != meta_graph);
|
|
|
|
|
MS_ASSERT(nullptr != fbNode);
|
|
|
|
|
if (cnode->inputs().size() <= 1) {
|
|
|
|
@ -281,38 +292,30 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
|
|
|
|
|
auto paramTensor = std::make_unique<schema::TensorT>();
|
|
|
|
|
auto abstractBase = paramNode->abstract();
|
|
|
|
|
if (abstractBase == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, "
|
|
|
|
|
<< paramNode->name();
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
|
|
|
|
|
MS_ASSERT(false);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, "
|
|
|
|
|
<< paramNode->name();
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name();
|
|
|
|
|
MS_ASSERT(false);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto abstractTensor =
|
|
|
|
|
utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
|
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
|
|
|
|
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
|
|
|
|
MS_ASSERT(typePtr != nullptr);
|
|
|
|
|
paramTensor->dataType = typePtr->type_id();
|
|
|
|
|
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
|
|
|
|
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, "
|
|
|
|
|
<< paramNode->name();
|
|
|
|
|
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
|
|
|
|
|
MS_ASSERT(false);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
paramTensor->dims =
|
|
|
|
|
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())
|
|
|
|
|
->shape();
|
|
|
|
|
auto paramValue =
|
|
|
|
|
std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
|
|
|
|
|
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
|
|
|
|
auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
|
|
|
|
|
if (paramValue != nullptr) {
|
|
|
|
|
paramTensor->nodeType = schema::NodeType_ValueNode;
|
|
|
|
|
paramTensor->data.resize(paramValue->tensor_size());
|
|
|
|
|
memcpy(paramTensor->data.data(), paramValue->tensor_addr(),
|
|
|
|
|
paramValue->tensor_size());
|
|
|
|
|
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
|
|
|
|
|
for (auto &ite : paramValue->quant_param()) {
|
|
|
|
|
auto quantPar = std::make_unique<schema::QuantParamT>();
|
|
|
|
|
quantPar->scale = ite->scale;
|
|
|
|
@ -326,8 +329,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
|
|
|
|
|
paramTensor->dataType = paramValue->tensor_type();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
nodeIdMap[paramNode->fullname_with_scope()] =
|
|
|
|
|
meta_graph->allTensors.size();
|
|
|
|
|
nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
|
|
|
|
|
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
|
|
|
|
|
meta_graph->allTensors.emplace_back(std::move(paramTensor));
|
|
|
|
|
} else if (inputNode->isa<ValueNode>()) {
|
|
|
|
@ -336,19 +338,15 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
|
|
|
|
|
auto value = valueNode->value();
|
|
|
|
|
if (value->isa<lite::tensor::Tensor>()) {
|
|
|
|
|
auto valueAbstract = valueNode->abstract();
|
|
|
|
|
auto abstractTensor =
|
|
|
|
|
utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
|
|
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
|
|
|
|
|
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
|
|
|
|
paramTensor->dataType = typePtr->type_id();
|
|
|
|
|
paramTensor->dims =
|
|
|
|
|
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())
|
|
|
|
|
->shape();
|
|
|
|
|
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
|
|
|
|
paramTensor->nodeType = schema::NodeType_ValueNode;
|
|
|
|
|
auto data = value->cast<lite::tensor::TensorPtr>();
|
|
|
|
|
paramTensor->data.resize(data->Size());
|
|
|
|
|
memcpy(paramTensor->data.data(), data->Data(), data->Size());
|
|
|
|
|
nodeIdMap[valueNode->fullname_with_scope()] =
|
|
|
|
|
meta_graph->allTensors.size();
|
|
|
|
|
nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size();
|
|
|
|
|
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
|
|
|
|
|
meta_graph->allTensors.emplace_back(std::move(paramTensor));
|
|
|
|
|
} else if (value->isa<mindspore::Int32Imm>()) {
|
|
|
|
@ -376,31 +374,45 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnfExporter::SetOpOutputNode(
|
|
|
|
|
const std::vector<schema::TensorT *> &outputTensors,
|
|
|
|
|
schema::MetaGraphT *graph, schema::CNodeT *cnode) {
|
|
|
|
|
void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector<schema::TensorT *> &outputTensors,
|
|
|
|
|
schema::MetaGraphT *graph, schema::CNodeT *fbnode) {
|
|
|
|
|
MS_ASSERT(nullptr != graph);
|
|
|
|
|
MS_ASSERT(nullptr != cnode);
|
|
|
|
|
std::string cnodeName = cnode->name;
|
|
|
|
|
MS_ASSERT(nullptr != fbnode);
|
|
|
|
|
std::string cnodeName = fbnode->name;
|
|
|
|
|
if (!outputTensors.empty()) {
|
|
|
|
|
int i = 0;
|
|
|
|
|
for (auto outputTensor : outputTensors) {
|
|
|
|
|
std::string name = cnodeName + "_o:" + std::to_string(i);
|
|
|
|
|
auto msTensor = new schema::TensorT();
|
|
|
|
|
msTensor->nodeType = schema::NodeType_Parameter;
|
|
|
|
|
nodeIdMap[name] = graph->allTensors.size();
|
|
|
|
|
cnode->outputIndex.emplace_back(graph->allTensors.size());
|
|
|
|
|
graph->allTensors.emplace_back(msTensor);
|
|
|
|
|
fbnode->outputIndex.emplace_back(graph->allTensors.size());
|
|
|
|
|
graph->allTensors.emplace_back(outputTensor);
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
|
|
|
|
|
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
|
|
|
|
|
for (int i = 0; i < tuple->size(); i++) {
|
|
|
|
|
auto msTensor = new schema::TensorT();
|
|
|
|
|
msTensor->nodeType = schema::NodeType_Parameter;
|
|
|
|
|
fbnode->outputIndex.emplace_back(graph->allTensors.size());
|
|
|
|
|
if (tuple->size() == 1) {
|
|
|
|
|
nodeIdMap[cnodeName] = graph->allTensors.size();
|
|
|
|
|
} else {
|
|
|
|
|
std::string name = cnodeName + "_o:" + std::to_string(i);
|
|
|
|
|
nodeIdMap[name] = graph->allTensors.size();
|
|
|
|
|
}
|
|
|
|
|
graph->allTensors.emplace_back(msTensor);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto msTensor = new schema::TensorT();
|
|
|
|
|
msTensor->nodeType = schema::NodeType_Parameter;
|
|
|
|
|
cnode->outputIndex.emplace_back(graph->allTensors.size());
|
|
|
|
|
fbnode->outputIndex.emplace_back(graph->allTensors.size());
|
|
|
|
|
nodeIdMap[cnodeName] = graph->allTensors.size();
|
|
|
|
|
graph->allTensors.emplace_back(msTensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) {
|
|
|
|
|
AnfExporter anfExporter;
|
|
|
|
|