|
|
|
@ -90,7 +90,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
|
|
|
|
|
// since parameters are all in block 0,
|
|
|
|
|
// it's enough to only scan send ops in block 0
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (!node->Op()) continue;
|
|
|
|
|
if (node->NodeType() != ir::Node::Type::kOperation) continue;
|
|
|
|
|
OpDesc *op = node->Op();
|
|
|
|
|
// TODO(Yancey1989): use a graceful method to find send op,
|
|
|
|
|
// instead of the hard code string
|
|
|
|
@ -108,7 +108,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
|
|
|
|
|
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
|
|
|
|
|
std::vector<std::string> recv_vars;
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (!node->Op()) continue;
|
|
|
|
|
if (node->NodeType() != ir::Node::Type::kOperation) continue;
|
|
|
|
|
OpDesc *op = node->Op();
|
|
|
|
|
// TODO(Yancey1989): use a graceful method to find recv op,
|
|
|
|
|
// instead of the hard code string
|
|
|
|
@ -149,10 +149,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
|
|
|
|
|
std::vector<std::string> input_var_names;
|
|
|
|
|
std::vector<std::string> output_var_names;
|
|
|
|
|
for (ir::Node *input : node->inputs) {
|
|
|
|
|
input_var_names.push_back(input->Var()->Name());
|
|
|
|
|
input_var_names.push_back(input->Name());
|
|
|
|
|
}
|
|
|
|
|
for (ir::Node *output : node->outputs) {
|
|
|
|
|
output_var_names.push_back(output->Var()->Name());
|
|
|
|
|
output_var_names.push_back(output->Name());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return checker(output_var_names, send_vars) ||
|
|
|
|
@ -181,13 +181,13 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
std::unique_ptr<Graph> graph) const {
|
|
|
|
|
// Rebuild the graph structure.
|
|
|
|
|
auto nodes = std::move(graph->nodes);
|
|
|
|
|
graph->nodes.clear();
|
|
|
|
|
LOG(ERROR) << "origin nodes count " << nodes.size();
|
|
|
|
|
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (node->Var()) {
|
|
|
|
|
all_vars_.emplace(node->Var()->Name(), node->Var());
|
|
|
|
|
if (node->NodeType() == ir::Node::Type::kVariable) {
|
|
|
|
|
all_vars_.emplace(node->Name(), node->Var());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -212,7 +212,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
|
|
|
|
|
// TODO(panyx0718): FIXME: nodes should be sorted by "program" order.
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (!node->Op()) continue;
|
|
|
|
|
if (node->NodeType() != ir::Node::Type::kOperation) continue;
|
|
|
|
|
if (boost::get<int>(
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
|
|
|
|
|
static_cast<int>(OpRole::kRPC)) {
|
|
|
|
@ -235,7 +235,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
if (op_dev_id != -1) { // This op only runs on one specific device.
|
|
|
|
|
CreateComputationalOp(&result, node.get(), op_dev_id);
|
|
|
|
|
for (ir::Node *n : node->outputs) {
|
|
|
|
|
var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id);
|
|
|
|
|
var_name_on_devices_.emplace(n->Name(), op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// This op runs on all devices, and its output may have parameter's
|
|
|
|
@ -351,10 +351,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
|
#else
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
|
|
|
|
|
local_scopes_, places_);
|
|
|
|
|
#endif
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(op_handle);
|
|
|
|
@ -367,8 +367,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
|
|
|
|
|
auto *out_var =
|
|
|
|
|
new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p);
|
|
|
|
|
auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(),
|
|
|
|
|
i, p_name, p);
|
|
|
|
|
vars.emplace_back(out_var);
|
|
|
|
|
op_handle->AddOutput(out_var);
|
|
|
|
|
}
|
|
|
|
@ -378,7 +378,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
|
|
|
|
|
ir::Node *node,
|
|
|
|
|
int dev_id) const {
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(),
|
|
|
|
|
new ComputationOpHandle(result->CreateOpNode(node->Op()),
|
|
|
|
|
local_scopes_[dev_id], places_[dev_id]));
|
|
|
|
|
CreateOpHandleIOs(result, node, dev_id);
|
|
|
|
|
}
|
|
|
|
@ -386,11 +386,12 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
|
|
|
|
|
result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_,
|
|
|
|
|
places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
|
|
|
|
|
result->CreateOpNode(nullptr), local_scopes_, places_));
|
|
|
|
|
result->CreateEmptyNode("allreduce"), local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>("ops").back().get();
|
|
|
|
|
|
|
|
|
@ -402,7 +403,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
|
|
|
|
|
auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p);
|
|
|
|
|
auto var =
|
|
|
|
|
new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
|
}
|
|
|
|
@ -411,11 +413,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
Graph *result, const std::vector<std::string> &datas) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
|
|
|
|
|
result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
|
|
|
|
|
result->CreateOpNode(nullptr), local_scopes_, places_));
|
|
|
|
|
result->CreateEmptyNode("data_balance"), local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>("ops").back().get();
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
@ -425,7 +428,7 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
auto &vars = result->Get<GraphVars>("vars")[i][d_name];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
op_handle->AddInput(vars.back().get());
|
|
|
|
|
auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i,
|
|
|
|
|
auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i,
|
|
|
|
|
d_name, p);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
@ -455,12 +458,12 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
auto param_grad = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
|
|
|
|
|
int dev_id = GetVarDeviceID(param_grad[1]);
|
|
|
|
|
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(),
|
|
|
|
|
param_grad[0]);
|
|
|
|
|
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]",
|
|
|
|
|
node->Op()->Type(), param_grad[0]);
|
|
|
|
|
return dev_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -481,8 +484,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = new ScaleLossGradOpHandle(
|
|
|
|
|
result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i],
|
|
|
|
|
places_[i], communication_dev_ctx);
|
|
|
|
|
result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(),
|
|
|
|
|
local_scopes_[i], places_[i], communication_dev_ctx);
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(op_handle);
|
|
|
|
|
|
|
|
|
|
// FIXME: Currently ScaleLossGradOp only use device_count as scale
|
|
|
|
@ -495,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
|
|
|
|
|
const std::string grad_var_name = GradVarName(loss_var_name_);
|
|
|
|
|
auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name];
|
|
|
|
|
size_t version = vars.size();
|
|
|
|
|
auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i,
|
|
|
|
|
auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i,
|
|
|
|
|
grad_var_name, places_[i]);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
@ -508,8 +511,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
|
|
|
|
|
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
|
|
|
|
|
auto p = places_[scope_idx];
|
|
|
|
|
auto s = local_scopes_[scope_idx];
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle(
|
|
|
|
|
result->CreateOpNode(node->Op()), *node->Op(), s, p));
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
|
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
|
|
|
|
|
CreateOpHandleIOs(result, node, scope_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -519,10 +522,10 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
|
|
|
|
|
int dst_dev_id) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
|
|
|
|
|
result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
|
|
|
|
|
result->CreateOpNode(nullptr), local_scopes_, places_));
|
|
|
|
|
result->CreateEmptyNode("reduce"), local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>("ops").back().get();
|
|
|
|
|
|
|
|
|
@ -535,7 +538,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
}
|
|
|
|
|
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
|
|
|
|
|
auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id,
|
|
|
|
|
auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id,
|
|
|
|
|
og, places_[dst_dev_id]);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
@ -548,7 +551,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
|
|
|
|
|
const std::string &prev_op_name) const {
|
|
|
|
|
for (auto &prev_op : result->Get<GraphOps>("ops")) {
|
|
|
|
|
if (prev_op->Name() == prev_op_name) {
|
|
|
|
|
auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy"));
|
|
|
|
|
auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy"));
|
|
|
|
|
prev_op->AddOutput(dep_var);
|
|
|
|
|
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
|
|
|
|
|
op->AddInput(dep_var);
|
|
|
|
@ -562,10 +565,10 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
|
|
|
|
|
std::vector<std::string> input_var_names;
|
|
|
|
|
std::vector<std::string> output_var_names;
|
|
|
|
|
for (ir::Node *input : node->inputs) {
|
|
|
|
|
input_var_names.push_back(input->Var()->Name());
|
|
|
|
|
input_var_names.push_back(input->Name());
|
|
|
|
|
}
|
|
|
|
|
for (ir::Node *output : node->outputs) {
|
|
|
|
|
output_var_names.push_back(output->Var()->Name());
|
|
|
|
|
output_var_names.push_back(output->Name());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node->Op()->Type() == "split_byref" ||
|
|
|
|
@ -606,16 +609,16 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
if (node->Op()->Type() == "send") {
|
|
|
|
|
op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name());
|
|
|
|
|
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
|
|
|
|
|
// the variable name which contains .block means it was split by
|
|
|
|
|
// split_byref op
|
|
|
|
|
// so that we can balance the variable blocks to all the pserver
|
|
|
|
|
// instances.
|
|
|
|
|
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
|
|
|
|
|
node->inputs[0]->Var()->Name().find(".block") == std::string::npos) {
|
|
|
|
|
node->inputs[0]->Name().find(".block") == std::string::npos) {
|
|
|
|
|
std::vector<std::string> input_var_names;
|
|
|
|
|
for (ir::Node *n : node->inputs) {
|
|
|
|
|
input_var_names.push_back(n->Var()->Name());
|
|
|
|
|
input_var_names.push_back(n->Name());
|
|
|
|
|
}
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(input_var_names);
|
|
|
|
|
for (auto &varname : input_var_names) {
|
|
|
|
@ -625,7 +628,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
|
|
|
|
|
} else if (node->Op()->Type() == "recv") {
|
|
|
|
|
std::vector<std::string> output_var_names;
|
|
|
|
|
for (ir::Node *n : node->outputs) {
|
|
|
|
|
output_var_names.push_back(n->Var()->Name());
|
|
|
|
|
output_var_names.push_back(n->Name());
|
|
|
|
|
}
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(output_var_names);
|
|
|
|
|
for (auto &varname : output_var_names) {
|
|
|
|
|