diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index b36ed8af9a..12822c64e9 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -80,6 +80,7 @@ void ProcessGraph(std::vector graphs, Scope *scope) { } } } + /* VLOG(3) << "delete all recv ops"; for (auto *node : nodes_to_delete) { // delete input edge @@ -105,6 +106,7 @@ void ProcessGraph(std::vector graphs, Scope *scope) { VLOG(3) << "delete node " << node->Name(); graphs[i]->RemoveNode(node); } + */ } // init communicator here if (send_varname_to_ctx.size() > 0) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index f7ec9d28de..0b9061ad60 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -127,6 +127,10 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { bool NeedCollectiveOps() const override { return false; } bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override { + if (node->Op()->Type() == "recv") { + node->Op()->SetAttr("do_not_run", true); + node->Op()->Flush(); + } return false; } diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 680b484d41..afbf7a4a23 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -36,6 +36,11 @@ class RecvOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { + bool do_not_run = Attr("do_not_run"); + if (do_not_run) { + VLOG(3) << "recv do not run!"; + return; + } std::vector epmap = Attr>("epmap"); std::vector varnames = Attr>("varnames"); @@ -126,6 +131,7 @@ This operator can get variables from server side. "(vector) " "the splited parameter varnames to be recved from pserver") .SetDefault(std::vector{}); + AddAttr("do_not_run", "").SetDefault(false); } };