!11996 fix bug of split optimizer.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
pull/11996/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b04e2aabff

@ -90,6 +90,7 @@ KernelWithIndex VisitSplitKernel(const AnfNodePtr &anf_node, size_t index) {
bool InputCheck(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto in_nums = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < in_nums; i++) {
auto in_node = VisitSplitKernel(AnfAlgo::GetInputNode(cnode, i), 0).first;
@ -98,7 +99,9 @@ bool InputCheck(const AnfNodePtr &node) {
return false;
}
if (in_node->isa<CNode>()) {
auto in_node_name = AnfAlgo::GetCNodeName(in_node->cast<CNodePtr>());
auto in_cnode = in_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(in_cnode);
auto in_node_name = AnfAlgo::GetCNodeName(in_cnode);
auto trans_input = AnfAlgo::VisitKernel(in_node, 0).first;
if (in_node_name == kTransDataOpName && (trans_input->isa<Parameter>() || trans_input->isa<ValueNode>())) {
MS_LOG(INFO) << "Data->TransData->split, can not optimizer.";
@ -107,9 +110,9 @@ bool InputCheck(const AnfNodePtr &node) {
if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) {
return false;
}
if ((AnfAlgo::HasNodeAttr("non_task", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "non_task")) ||
(AnfAlgo::HasNodeAttr("nop_node", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "nop_node"))) {
MS_LOG(INFO) << "Input has non_task or nop_node attr, can not optimizer.";
if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) ||
opt::IsNopNode(in_cnode)) {
MS_LOG(INFO) << "Input is nop node or has non_task attr, can not optimizer.";
return false;
}
}
@ -140,7 +143,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
return false;
}
auto op_name = AnfAlgo ::GetCNodeName(item);
if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(node)) {
if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(item)) {
MS_LOG(INFO) << "Next node is " << item->fullname_with_scope() << ", not a invalid node, can not optimizer.";
return false;
}

Loading…
Cancel
Save