@@ -24,6 +24,68 @@ namespace paddle {
namespace framework {
namespace ir {

std::vector<std::string> FindDistTrainSendVars(
    const std::vector<ir::Node *> &nodes) {
  std::vector<std::string> send_vars;
  // since parameters are all in block 0,
  // it's enough to only scan send ops in block 0
  for (auto &node : nodes) {
    auto op_vars = node->Op()->InputArgumentNames();
    send_vars.reserve(send_vars.size() +
                      std::distance(op_vars.begin(), op_vars.end()));
    send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
  }
  return send_vars;
}
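
// Note: the reserve() before insert() keeps the accumulation loop from
// reallocating the vector once per send op; the same pattern is reused for
// the recv vars below.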

std::vector<std::string> FindDistTrainRecvVars(
    const std::vector<ir::Node *> &nodes) {
  std::vector<std::string> recv_vars;
  for (auto &node : nodes) {
    auto op_vars = node->Op()->OutputArgumentNames();
    recv_vars.reserve(recv_vars.size() +
                      std::distance(op_vars.begin(), op_vars.end()));
    recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
  }
  return recv_vars;
}
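
// Both helpers simply flatten the argument names of the nodes they are given;
// the Graph constructor below calls them with the pre-collected send_ops and
// recv_ops lists.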

bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
                   const std::vector<std::string> &recv_vars) {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
    return false;
  }

  /**
   * Check whether any name in opvars carries the `.block` suffix and also
   * appears in rpc_vars.
   */
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      // a variable name with the suffix `.block` means it's a variable
      // split by the DistributeTranspiler
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
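      // e.g. a parameter "fc_w" split across parameter servers may show up
      // as "fc_w.block0", "fc_w.block1", ... (illustrative names; the exact
      // pattern is defined by the transpiler)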
      if (var.find(".block") != std::string::npos &&
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };

  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  return checker(output_var_names, send_vars) ||
         checker(input_var_names, recv_vars);
}
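
// In short: an op takes part in distributed training when it writes one of
// the send vars or reads one of the recv vars; the Graph constructor below
// uses this predicate to find the transpiler-generated ops.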

Graph::Graph(const ProgramDesc &program) : program_(program) {
  VLOG(3) << "block in program:" << program_.Size();
  std::unordered_map<std::string, VarDesc *> all_vars;
@@ -104,6 +166,21 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
      dep_var->outputs.push_back(fetch_bar);
    }
  }

  std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
  std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
  for (ir::Node *node : Nodes()) {
    if (IsDistTrainOp(node, send_vars, recv_vars)) {
      if (fetch_bar && node->Name() == "concat") {
        ir::Node *dep_var = CreateControlDepVar();
        fetch_bar->outputs.push_back(dep_var);
        dep_var->inputs.push_back(fetch_bar);
        node->inputs.push_back(dep_var);
        dep_var->outputs.push_back(node);
      }
    }
  }
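  // The control-dependency var enforces the edge chain
  //   fetch_bar -> dep_var -> concat,
  // making each matching concat op wait for fetch_barrier to finish.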

  /**
   * We only handle write after read (WAR), since the program should not
   * contain write after write. If there are write-after-write operators,
   * we need