@@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
   }
 
   for (ir::Node *output : node->outputs) {
-    CreateOpOutput(result, op_handle, output, p, place_id);
+    ir::Node *new_node = nullptr;
+    if (output->Var()) {
+      new_node = result->CreateVarNode(output->Var());
+    } else {
+      new_node =
+          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
+    }
+    CreateOpOutput(result, op_handle, new_node, p, place_id);
   }
 }
 
@@ -246,7 +253,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
         if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
           node->Op()->SetAttr("throw_eof_exp", false);
           CreateComputationalOps(&result, node.get(), places_.size());
-          // TODO(panyx0718): builder shouldn't depend on the out logic of
+          // TODO(paddle-dev): builder shouldn't depend on the out logic of
           // a specific op.
           const auto &data_var_names = node->Op()->Output("Out");
           InsertDataBalanceOp(&result, data_var_names);
@@ -354,11 +361,13 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
-                                          local_scopes_, places_, nccl_ctxs_);
+  auto *op_handle = new BroadcastOpHandle(
+      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_);
 #else
-  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
-                                          local_scopes_, places_);
+  auto *op_handle = new BroadcastOpHandle(
+      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_);
 #endif
   result->Get<GraphOps>("ops").emplace_back(op_handle);
 
@@ -370,8 +379,9 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
-    auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(),
-                                  i, p_name, p);
+    auto *out_var = new VarHandle(
+        result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
+        i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }
@@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_,
-                            places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce"), local_scopes_, places_));
+      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+      local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
@@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
     op_handle->AddInput(prev_grad.get());
 
     auto var =
-        new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p);
+        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
+                      vars.size(), i, og, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }
@@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"),
-                              local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-      result->CreateEmptyNode("data_balance"), local_scopes_, places_));
+      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
+      local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
@@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
       auto &vars = result->Get<GraphVars>("vars")[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
-      auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i,
-                               d_name, p);
+      auto var = new VarHandle(
+          result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
+          vars.size(), i, d_name, p);
       vars.emplace_back(var);
       op_handle->AddOutput(var);
     }
@@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
         platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
 #endif
     auto *op_handle = new ScaleLossGradOpHandle(
-        result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(),
-        local_scopes_[i], places_[i], communication_dev_ctx);
+        result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
+        local_scopes_.size(), local_scopes_[i], places_[i],
+        communication_dev_ctx);
     result->Get<GraphOps>("ops").emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
     // loss->pending_ops_.emplace_back(op_handle);
     // op_handle->inputs_.emplace_back(loss);
 
-    // TODO(panyx0718): GradVarName(loss_var_name_)
-    const std::string grad_var_name = GradVarName(loss_var_name_);
-    auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name];
-    size_t version = vars.size();
-    auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i,
-                             grad_var_name, places_[i]);
-    vars.emplace_back(var);
-    op_handle->AddOutput(var);
+    CreateOpOutput(result, op_handle,
+                   result->CreateEmptyNode(GradVarName(loss_var_name_),
+                                           ir::Node::Type::kVariable),
+                   places_[i], i);
   }
 }
 
@@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
   result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-      result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_));
+      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-      result->CreateEmptyNode("reduce"), local_scopes_, places_));
+      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
+      local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
@@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
     op_handle->AddInput(prev_grad.get());
   }
   auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
-  auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id,
-                           og, places_[dst_dev_id]);
+  auto var =
+      new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
+                    vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
   return var;
@@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
   for (auto &prev_op : result->Get<GraphOps>("ops")) {
     if (prev_op->Name() == prev_op_name) {
-      auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy"));
+      auto *dep_var = new DummyVarHandle(
+          result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
       prev_op->AddOutput(dep_var);
       result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       op->AddInput(dep_var);