@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -79,9 +78,39 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
     CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
 
-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send_vars" || op->Type() == "send") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "recv" || op->Type() == "send") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
     return false;
   }
@@ -89,21 +118,23 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
    * Check any of opvars contains `.block` and in sendvars
    */
   auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
     for (auto &var : opvars) {
       if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
       }
     }
     return false;
   };
 
-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  if (op.Type() == "split" || op.Type() == "split_byref" ||
+      op.Type() == "split_selected_rows") {
+    return checker(op.OutputArgumentNames(), send_vars);
   } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+    return checker(op.InputArgumentNames(), recv_vars);
   }
 
   return false;
 }
@@ -132,8 +163,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
 
-  // Find "send" op first for split is in front of send.
-  OpDesc *send_op = GetSendOpDesc(program);
+  // Find send/recv vars so that we can place the distributed training
+  // related ops on place 0.
+  auto send_vars = FindDistTrainSendVars(program);
+  auto recv_vars = FindDistTrainRecvVars(program);
 
   size_t cur_device_id = 0;
   std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -147,8 +180,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       CreateRPCOp(&result, *op);
-    } else if (IsDistTrainOp(*op, send_op)) {
-      CreateComputationalOps(&result, *op, 1);
+    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
+      // CreateComputationalOps(&result, *op, 1);
+      CreateComputationalOp(&result, *op, 0);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -213,9 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
 
   if (VLOG_IS_ON(10)) {
-    std::string filename = "/tmp/graph";
-    std::ofstream fout(filename);
-    PrintGraphviz(*graph, fout);
+    std::ostringstream sout;
+    PrintGraphviz(*graph, sout);
+    VLOG(10) << sout.str();
   }
 
   return std::unique_ptr<SSAGraph>(graph);
@@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
   }
   return nullptr;
 }
 
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
@@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }
 
-void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
-                                        std::string op_name) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+                                        const std::string &prev_op_name) const {
   for (auto &prev_op : result->ops_) {
-    if (prev_op->Name() == op_name) {
+    if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
       result->dep_vars_.emplace(dep_var);
-      result->ops_.back().get()->AddInput(dep_var);
+      op->AddInput(dep_var);
     }
   }
 }
@@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                           const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
   VLOG(3) << "create rpc op: " << op.Type();
   result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
 
   if (op.Type() == "send_barrier") {
-    ConnectOp(result, "send_vars");
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
   } else if (op.Type() == "recv") {
-    ConnectOp(result, "send_barrier");
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result, "recv");
+    ConnectOp(result, result->ops_.back().get(), "recv");
   } else if (op.Type() == "send" || op.Type() == "send_vars") {
     // do nothing
   } else {
@@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   }
 
   // FIXME(wuyi): send op always copy from GPU 0
-  // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);