|
|
|
@ -55,60 +55,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
|
|
|
|
|
MS_ASSERT(cnode != nullptr);
|
|
|
|
|
bool has_tuple_get_item = false;
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.clear();
|
|
|
|
|
inputs.emplace_back(cnode->input(0));
|
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
|
AnfNodePtr input_node = cnode->input(i);
|
|
|
|
|
if (!input_node->isa<CNode>()) {
|
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto tuple_get_item_node = utils::cast<CNodePtr>(input_node);
|
|
|
|
|
if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) {
|
|
|
|
|
has_tuple_get_item = true;
|
|
|
|
|
inputs.emplace_back(tuple_get_item_node->input(1));
|
|
|
|
|
AnfNodePtr indexNode = tuple_get_item_node->input(2);
|
|
|
|
|
if (!utils::isa<ValueNode>(indexNode)) {
|
|
|
|
|
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto value_node = utils::cast<ValueNodePtr>(indexNode);
|
|
|
|
|
} else {
|
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (has_tuple_get_item) {
|
|
|
|
|
cnode->set_inputs(inputs);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode) {
|
|
|
|
|
MS_ASSERT(meta_graphT != nullptr);
|
|
|
|
|
MS_ASSERT(cnode != nullptr);
|
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
|
auto input_anode = cnode->input(i);
|
|
|
|
|
if (!input_anode->isa<CNode>()) {
|
|
|
|
|
MS_LOG(ERROR) << "Node of Return's input is not CNode";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
|
|
|
|
std::string input_name = input_anode->fullname_with_scope();
|
|
|
|
|
auto iter = node_id_map_.find(input_name);
|
|
|
|
|
if (iter == node_id_map_.end()) {
|
|
|
|
|
MS_LOG(ERROR) << "Could not find output node";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto graph_output = iter->second;
|
|
|
|
|
meta_graphT->outputIndex.emplace_back(graph_output);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
|
|
|
|
const std::shared_ptr<PrimitiveC> primitive,
|
|
|
|
|
const std::unique_ptr<schema::CNodeT> &dst_node) {
|
|
|
|
@ -182,6 +128,28 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
|
|
|
|
schema::CNodeT *return_node) {
|
|
|
|
|
MS_ASSERT(nullptr != meta_graph);
|
|
|
|
|
MS_ASSERT(nullptr != return_node);
|
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
|
|
|
|
auto input_node = cnode->input(i);
|
|
|
|
|
if (input_node->isa<CNode>()) {
|
|
|
|
|
auto ret = ConvertInputCNode(input_node, return_node);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "obtain outputs failed";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < return_node->inputIndex.size(); ++i) {
|
|
|
|
|
meta_graphT->outputIndex.push_back(return_node->inputIndex[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
|
|
|
|
auto cnodes = func_graph->GetOrderedCnodes();
|
|
|
|
|
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
|
|
|
@ -202,24 +170,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
|
|
|
|
}
|
|
|
|
|
RemoveIfMakeTuple(cnode);
|
|
|
|
|
|
|
|
|
|
auto node = std::make_unique<schema::CNodeT>();
|
|
|
|
|
|
|
|
|
|
if (primT->value.type == schema::PrimitiveType_Return) {
|
|
|
|
|
AddOutPutIfReturn(meta_graphT, cnode);
|
|
|
|
|
node->name = "return_node";
|
|
|
|
|
SetGraphoutputIndex(cnode, meta_graphT, node.get());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto node = std::make_unique<schema::CNodeT>();
|
|
|
|
|
node->name = cnode->fullname_with_scope();
|
|
|
|
|
node->nodeType = schema::NodeType_CNode;
|
|
|
|
|
|
|
|
|
|
node->name = cnode->fullname_with_scope();
|
|
|
|
|
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
|
|
|
|
|
auto ret = SetOpInputNode(cnode, meta_graphT, node.get());
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "SetOpInputNode failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SetOpOutputNode(cnode, meta_graphT, node.get());
|
|
|
|
|
|
|
|
|
|
ret = ConvertQuantParam(meta_graphT, primitiveT_value, node);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "ConvertQuantParam failed";
|
|
|
|
|