|
|
|
@ -328,12 +328,16 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
|
#else
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
|
|
|
|
|
auto *op_handle =
|
|
|
|
|
new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_);
|
|
|
|
|
#endif
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(op_handle);
|
|
|
|
|
|
|
|
|
|
auto *in =
|
|
|
|
|
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
|
|
|
|
|
op_handle->AddInput(in);
|
|
|
|
@ -341,8 +345,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
|
|
|
|
|
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
|
|
|
|
|
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
|
|
|
|
|
auto *out_var =
|
|
|
|
|
new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p);
|
|
|
|
|
vars.emplace_back(out_var);
|
|
|
|
|
op_handle->AddOutput(out_var);
|
|
|
|
|
}
|
|
|
|
@ -351,19 +357,21 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
|
|
|
|
|
const OpDesc &op,
|
|
|
|
|
int dev_id) const {
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle(
|
|
|
|
|
result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id]));
|
|
|
|
|
CreateOpHandleIOs(result, op, dev_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
|
|
|
|
|
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new AllReduceOpHandle(local_scopes_, places_));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
|
|
|
|
|
result->nodes.back().get(), local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>("ops").back().get();
|
|
|
|
|
|
|
|
|
@ -375,7 +383,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
|
|
|
|
|
auto var = new VarHandle(vars.size(), i, og, p);
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
|
|
|
|
|
auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
|
}
|
|
|
|
@ -383,12 +392,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
Graph *result, const std::vector<std::string> &datas) const {
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
|
|
|
|
|
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new DataBalanceOpHandle(local_scopes_, places_));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
|
|
|
|
|
result->nodes.back().get(), local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>("ops").back().get();
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
@ -398,7 +408,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
auto &vars = result->Get<GraphVars>("vars")[i][d_name];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
op_handle->AddInput(vars.back().get());
|
|
|
|
|
auto var = new VarHandle(vars.size(), i, d_name, p);
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
|
|
|
|
|
auto var =
|
|
|
|
|
new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
|
}
|
|
|
|
@ -452,9 +464,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
|
|
|
|
|
auto *communication_dev_ctx =
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
auto *op_handle =
|
|
|
|
|
new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
auto *op_handle = new ScaleLossGradOpHandle(
|
|
|
|
|
result->nodes.back().get(), local_scopes_.size(), local_scopes_[i],
|
|
|
|
|
places_[i], communication_dev_ctx);
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(op_handle);
|
|
|
|
|
|
|
|
|
@ -475,8 +487,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
|
|
|
|
|
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
|
|
|
|
|
auto p = places_[scope_idx];
|
|
|
|
|
auto s = local_scopes_[scope_idx];
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new ComputationOpHandle(op, s, p));
|
|
|
|
|
new ComputationOpHandle(result->nodes.back().get(), op, s, p));
|
|
|
|
|
CreateOpHandleIOs(result, op, scope_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -484,12 +497,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
|
|
|
|
|
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
|
|
|
|
|
const std::string &og,
|
|
|
|
|
int dst_dev_id) const {
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
|
|
|
|
|
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new ReduceOpHandle(local_scopes_, places_));
|
|
|
|
|
new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>("ops").back().get();
|
|
|
|
|
|
|
|
|
@ -502,7 +516,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
}
|
|
|
|
|
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
|
|
|
|
|
auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
|
|
|
|
|
auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id,
|
|
|
|
|
og, places_[dst_dev_id]);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
|
return var;
|
|
|
|
@ -514,7 +530,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
|
|
|
|
|
const std::string &prev_op_name) const {
|
|
|
|
|
for (auto &prev_op : result->Get<GraphOps>("ops")) {
|
|
|
|
|
if (prev_op->Name() == prev_op_name) {
|
|
|
|
|
auto *dep_var = new DummyVarHandle();
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
|
|
|
|
|
auto *dep_var = new DummyVarHandle(result->nodes.back().get());
|
|
|
|
|
prev_op->AddOutput(dep_var);
|
|
|
|
|
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
|
|
|
|
|
op->AddInput(dep_var);
|
|
|
|
@ -587,8 +604,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
|
|
|
|
|
op.Type());
|
|
|
|
|
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
|
|
|
|
|
op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id]));
|
|
|
|
|
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id],
|
|
|
|
|
op.Type(), places_[op_dev_id]));
|
|
|
|
|
|
|
|
|
|
if (op.Type() == "send_barrier") {
|
|
|
|
|
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
|
|
|
|
|