|
|
|
@ -90,6 +90,7 @@ KernelWithIndex VisitSplitKernel(const AnfNodePtr &anf_node, size_t index) {
|
|
|
|
|
bool InputCheck(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto in_nums = AnfAlgo::GetInputTensorNum(node);
|
|
|
|
|
for (size_t i = 0; i < in_nums; i++) {
|
|
|
|
|
auto in_node = VisitSplitKernel(AnfAlgo::GetInputNode(cnode, i), 0).first;
|
|
|
|
@ -98,7 +99,9 @@ bool InputCheck(const AnfNodePtr &node) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (in_node->isa<CNode>()) {
|
|
|
|
|
auto in_node_name = AnfAlgo::GetCNodeName(in_node->cast<CNodePtr>());
|
|
|
|
|
auto in_cnode = in_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(in_cnode);
|
|
|
|
|
auto in_node_name = AnfAlgo::GetCNodeName(in_cnode);
|
|
|
|
|
auto trans_input = AnfAlgo::VisitKernel(in_node, 0).first;
|
|
|
|
|
if (in_node_name == kTransDataOpName && (trans_input->isa<Parameter>() || trans_input->isa<ValueNode>())) {
|
|
|
|
|
MS_LOG(INFO) << "Data->TransData->split, can not optimizer.";
|
|
|
|
@ -107,9 +110,9 @@ bool InputCheck(const AnfNodePtr &node) {
|
|
|
|
|
if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if ((AnfAlgo::HasNodeAttr("non_task", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "non_task")) ||
|
|
|
|
|
(AnfAlgo::HasNodeAttr("nop_node", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "nop_node"))) {
|
|
|
|
|
MS_LOG(INFO) << "Input has non_task or nop_node attr, can not optimizer.";
|
|
|
|
|
if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) ||
|
|
|
|
|
opt::IsNopNode(in_cnode)) {
|
|
|
|
|
MS_LOG(INFO) << "Input is nop node or has non_task attr, can not optimizer.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -140,7 +143,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto op_name = AnfAlgo ::GetCNodeName(item);
|
|
|
|
|
if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(node)) {
|
|
|
|
|
if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(item)) {
|
|
|
|
|
MS_LOG(INFO) << "Next node is " << item->fullname_with_scope() << ", not a invalid node, can not optimizer.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|