diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 52641260a6..e9aad5d264 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -81,10 +81,6 @@ void ProcessGraph(std::vector graphs, Scope *scope) { nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; - } else if (node->Name() == "lookup_table" || node->Name() == "nce" || - node->Name() == "hierarchical_sigmoid") { - VLOG(0) << "set " << node->Name() << " op remote_prefetch to false"; - node->Op()->SetAttr("remote_prefetch", false); } } } diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index a3fe9e8b13..82d003fad7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -127,8 +127,13 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override { if (node->Op()->Type() == "recv") { + VLOG(0) << "set recv op do_not_run to true"; node->Op()->SetAttr("do_not_run", true); node->Op()->Flush(); + } else if (node->Name() == "lookup_table" || node->Name() == "nce" || + node->Name() == "hierarchical_sigmoid") { + VLOG(0) << "set " << node->Name() << " op remote_prefetch to false"; + node->Op()->SetAttr("remote_prefetch", false); } return false; }