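Summary: replace raw boost::any_cast lookups on the public Graph::attrs map with
the typed Graph::Set/Get/Erase accessors throughout MultiDevSSAGraphBuilder.
For readers unfamiliar with the new API, below is a minimal sketch of what such
a typed attribute map can look like. It is illustrative only and not part of
the patch: it uses std::any rather than boost::any, the names are assumed from
the calls in this diff, and the real ir::Graph implementation differs in detail.

    #include <any>
    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    // Minimal model of a typed attribute map. `Set` takes ownership of a
    // heap-allocated attribute, `Get` returns a typed reference, and `Erase`
    // releases ownership back to the caller (as in Erase<GraphVars>("vars")).
    class Graph {
     public:
      ~Graph() {
        for (auto &kv : deleters_) kv.second();  // free remaining attributes
      }

      template <typename AttrType>
      void Set(const std::string &name, AttrType *attr) {
        attrs_[name] = attr;
        deleters_[name] = [attr] { delete attr; };
      }

      template <typename AttrType>
      AttrType &Get(const std::string &name) {
        // Replaces *boost::any_cast<AttrType *>(attrs["name"]); throws on a
        // missing name or a type mismatch instead of silently misbehaving.
        return *std::any_cast<AttrType *>(attrs_.at(name));
      }

      template <typename AttrType>
      std::unique_ptr<AttrType> Erase(const std::string &name) {
        auto *attr = std::any_cast<AttrType *>(attrs_.at(name));
        attrs_.erase(name);
        deleters_.erase(name);  // the caller now owns `attr`
        return std::unique_ptr<AttrType>(attr);
      }

     private:
      std::map<std::string, std::any> attrs_;
      std::map<std::string, std::function<void()>> deleters_;
    };

Usage then mirrors the new calls in this diff, e.g. result.Set("ops", new
GraphOps) at construction and result->Get<GraphOps>("ops").emplace_back(...)
at each use site.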
@@ -70,8 +70,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
 
@@ -179,13 +178,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   std::unordered_set<std::string> og_has_been_broadcast;
 
-  // We cannot invoke resize. It is a bug of GCC 4.8
-  result.attrs["vars"] = new std::vector<
-      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
-      places_.size());
-  result.attrs["dep_vars"] =
-      new std::unordered_set<std::unique_ptr<VarHandleBase>>();
-  result.attrs["ops"] = new std::vector<std::unique_ptr<OpHandleBase>>();
+  result.Set("vars", new GraphVars(places_.size()));
+  result.Set("dep_vars", new GraphDepVars);
+  result.Set("ops", new GraphOps);
 
   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
   auto send_vars = FindDistTrainSendVars(program);
@@ -308,13 +303,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
 
   std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
-  ssa_graph->vars_ =
-      std::move(*boost::any_cast<GraphVars *>(graph->attrs["vars"]));
-  ssa_graph->ops_ =
-      std::move(*boost::any_cast<GraphOps *>(graph->attrs["ops"]));
-  ssa_graph->dep_vars_ =
-      std::move(*boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]));
-
+  ssa_graph->vars_ = std::move(*graph->Erase<GraphVars>("vars"));
+  ssa_graph->ops_ = std::move(*graph->Erase<GraphOps>("ops"));
+  ssa_graph->dep_vars_ = std::move(*graph->Erase<GraphDepVars>("dep_vars"));
   return std::move(ssa_graph);
 }
 
@@ -347,20 +338,15 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
 #else
   auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
 #endif
-
-  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
-  auto *in = boost::any_cast<GraphVars *>(result->attrs["vars"])
-                 ->at(src_dev_id)
-                 .at(p_name)
-                 .back()
-                 .get();
+  result->Get<GraphOps>("ops").emplace_back(op_handle);
+  auto *in =
+      result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
   op_handle->AddInput(in);
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars =
-        boost::any_cast<GraphVars *>(result->attrs["vars"])->at(i).at(p_name);
+    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
     auto *out_var = new VarHandle(vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
@@ -370,28 +356,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
 void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(
-          new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
+  result->Get<GraphOps>("ops").emplace_back(
+      new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, op, dev_id);
 }
 
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new AllReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
+    auto &vars = result->Get<GraphVars>("vars")[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
@@ -405,21 +389,18 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(
-          new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new DataBalanceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
-      auto &vars =
-          (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][d_name];
+      auto &vars = result->Get<GraphVars>("vars")[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
       auto var = new VarHandle(vars.size(), i, d_name, p);
@@ -480,7 +461,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
     auto *op_handle =
         new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
                                   places_[i], communication_dev_ctx);
-    boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
+    result->Get<GraphOps>("ops").emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
     // factor. So it does not depend on any other operators.
@@ -499,8 +480,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    boost::any_cast<GraphOps *>(result->attrs["ops"])
-        ->emplace_back(new ComputationOpHandle(op, s, p));
+    result->Get<GraphOps>("ops").emplace_back(
+        new ComputationOpHandle(op, s, p));
     CreateOpHandleIOs(result, op, scope_idx);
   }
 }
@@ -509,25 +490,23 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new ReduceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new ReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
+    auto &vars = result->Get<GraphVars>("vars")[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
   }
-  auto &vars =
-      (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[dst_dev_id][og];
+  auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
   auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
@@ -538,12 +517,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
 // on it.
 void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
-  for (auto &prev_op : (*boost::any_cast<GraphOps *>(result->attrs["ops"]))) {
+  for (auto &prev_op : result->Get<GraphOps>("ops")) {
     if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
-      boost::any_cast<GraphDepVars *>(result->attrs["dep_vars"])
-          ->emplace(dep_var);
+      result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       op->AddInput(dep_var);
     }
   }
@@ -579,8 +557,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
 
   CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
               "fetch_barrier");
   }
 }
@@ -615,22 +592,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  op.Type());
 
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(),
-                                     places_[op_dev_id]));
+  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
+      op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id]));
 
   if (op.Type() == "send_barrier") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
-              "send");
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
   } else if (op.Type() == "recv") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
               "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
-              "recv");
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
   } else if (op.Type() == "send") {
     // do nothing
   } else {