|
|
|
@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
|
|
|
|
|
}
|
|
|
|
|
auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
|
|
|
|
|
while (index < input_num) {
|
|
|
|
|
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
|
|
|
|
|
++index;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replacing_node);
|
|
|
|
|
if (!replacing_node->isa<CNode>()) {
|
|
|
|
|
new_depend_inputs.push_back(replacing_node);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto replacing_cnode = replacing_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replacing_cnode);
|
|
|
|
|
// Deal with the make_tuple with TransData or Cast inputs.
|
|
|
|
|
auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode);
|
|
|
|
|
if (make_tuple_replace_node != nullptr) {
|
|
|
|
|
new_depend_inputs.push_back(make_tuple_replace_node);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
|
|
|
|
|
if (replace_node == nullptr) {
|
|
|
|
|
new_depend_inputs.push_back(replacing_node);
|
|
|
|
|
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: "
|
|
|
|
|
<< node->DebugString();
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto replace_node = GetConvertNode(func_graph, node, index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replace_node);
|
|
|
|
|
new_depend_inputs.push_back(replace_node);
|
|
|
|
|
++index;
|
|
|
|
|
}
|
|
|
|
|
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
|
|
|
|
CNodePtr new_depend;
|
|
|
|
|
CNodePtr new_depend = nullptr;
|
|
|
|
|
if (kernel_graph == nullptr) {
|
|
|
|
|
new_depend = func_graph->NewCNode(new_depend_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_depend);
|
|
|
|
@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
|
|
|
|
|
}
|
|
|
|
|
return new_depend;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|
|
const size_t index) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto depend_cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replacing_node);
|
|
|
|
|
if (!replacing_node->isa<CNode>()) {
|
|
|
|
|
return replacing_node;
|
|
|
|
|
}
|
|
|
|
|
auto replacing_cnode = replacing_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(replacing_cnode);
|
|
|
|
|
// Deal with the make_tuple with TransData or Cast inputs.
|
|
|
|
|
auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
|
|
|
|
|
if (make_tuple_replace_node != nullptr) {
|
|
|
|
|
return make_tuple_replace_node;
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
|
|
|
|
|
if (replace_node == nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
|
|
|
|
|
return replacing_node;
|
|
|
|
|
}
|
|
|
|
|
return replace_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|