|
|
|
@ -94,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
|
|
|
|
|
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
|
|
|
|
|
const std::vector<ir::Node *> &nodes) const {
|
|
|
|
|
std::vector<std::string> send_vars;
|
|
|
|
|
// since parameters are all in block 0,
|
|
|
|
|
// it's enough to only scan send ops in block 0
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (node->NodeType() != ir::Node::Type::kOperation) continue;
|
|
|
|
|
OpDesc *op = node->Op();
|
|
|
|
|
// TODO(Yancey1989): use a graceful method to find send op,
|
|
|
|
|
// instead of the the hard code string
|
|
|
|
@ -114,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
|
|
|
|
|
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
|
|
|
|
|
const std::vector<ir::Node *> &nodes) const {
|
|
|
|
|
std::vector<std::string> recv_vars;
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (node->NodeType() != ir::Node::Type::kOperation) continue;
|
|
|
|
|
OpDesc *op = node->Op();
|
|
|
|
|
// TODO(Yancey1989): use a graceful method to find recv op,
|
|
|
|
|
// instead of the hard code string
|
|
|
|
@ -214,6 +212,19 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Verify that no operations before optimize ops depends on optimize ops.
|
|
|
|
|
std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
|
|
|
|
|
optimize_ops.end());
|
|
|
|
|
for (size_t i = 0; i < last_backward; ++i) {
|
|
|
|
|
for (ir::Node *in : sorted_ret[i]->inputs) {
|
|
|
|
|
for (ir::Node *pre_n : in->inputs) {
|
|
|
|
|
PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
|
|
|
|
|
"optimize operations cannot be depended by forward "
|
|
|
|
|
"or backward node %s -> %s",
|
|
|
|
|
pre_n->Name(), sorted_ret[i]->Name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
|
|
|
|
|
optimize_ops.end());
|
|
|
|
|
return sorted_ret;
|
|
|
|
@ -221,18 +232,16 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
std::unique_ptr<ir::Graph> graph) const {
|
|
|
|
|
// Rebuild the graph structure.
|
|
|
|
|
// Give the topology sort order and rebuild the graph structure.
|
|
|
|
|
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
|
|
|
|
|
auto nodes = std::move(graph->nodes);
|
|
|
|
|
graph->nodes.clear();
|
|
|
|
|
auto nodes = graph->ReleaseNodes();
|
|
|
|
|
ir::Graph &result = *graph;
|
|
|
|
|
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (node->NodeType() == ir::Node::Type::kVariable) {
|
|
|
|
|
all_vars_.emplace(node->Name(), node->Var());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::Graph &result = *graph;
|
|
|
|
|
std::unordered_set<std::string> og_has_been_broadcast;
|
|
|
|
|
|
|
|
|
|
// We cannot invoke resize. It is a bug of GCC 4.8
|
|
|
|
@ -242,8 +251,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
|
|
|
|
|
// find send/recv vars so that we can place the distributed training
|
|
|
|
|
// realted op in the place 0
|
|
|
|
|
auto send_vars = FindDistTrainSendVars(nodes);
|
|
|
|
|
auto recv_vars = FindDistTrainRecvVars(nodes);
|
|
|
|
|
auto send_vars = FindDistTrainSendVars(sorted_ops);
|
|
|
|
|
auto recv_vars = FindDistTrainRecvVars(sorted_ops);
|
|
|
|
|
|
|
|
|
|
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
|
|
|
|
|
bcast_var_name_set.resize(places_.size());
|
|
|
|
@ -589,8 +598,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
|
|
|
|
|
const std::string &prev_op_name) const {
|
|
|
|
|
for (auto &prev_op : result->Get<GraphOps>("ops")) {
|
|
|
|
|
if (prev_op->Name() == prev_op_name) {
|
|
|
|
|
auto *dep_var = new DummyVarHandle(
|
|
|
|
|
result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
|
|
|
|
|
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
|
|
|
|
|
prev_op->AddOutput(dep_var);
|
|
|
|
|
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
|
|
|
|
|
op->AddInput(dep_var);
|
|
|
|
|