@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
 
   size_t cur_device_id = 0;
   bool is_forwarding = true;
+  bool is_dist_train = false;
 
   for (ir::Node *node : sorted_ops) {
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      CreateRPCOp(&result, node);
+      int op_dev_id = CreateRPCOp(&result, node);
+      PADDLE_ENFORCE(op_dev_id != -1,
+                     "Can not schedule the RPC operator to the right place.");
+      if (node->Op()->Type() == "recv") {
+        auto recv_vars_attr =
+            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+        PADDLE_ENFORCE(recv_vars_attr.size() == 2UL);  // [parameter, gradient]
+        if (recv_vars_attr[0].find(".block") == std::string::npos) {
+          bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
+        }
+      }
+      is_dist_train = true;
     } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
-      CreateDistTrainOp(&result, node);
+      int op_dev_id = CreateDistTrainOp(&result, node);
+      if (node->Op()->Type() == "concat") {
+        auto origin_param_name = node->Op()->OutputArgumentNames()[0];
+        bcast_var_name_set[op_dev_id].emplace(origin_param_name);
+      }
     } else if (IsScaleLossOp(node)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
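The recv and concat branches added above only do bookkeeping: they record, per device, which received parameters must later be broadcast to the other devices. Below is a minimal standalone sketch of that bookkeeping (illustration only, not Paddle code; the parameter names and device count are made up here). Whole parameters are recorded directly, while ".block" entries, i.e. slices of a parameter split across parameter servers, are skipped and the concat op that reassembles them registers the original parameter name instead.

#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // One set of parameter names per device, like bcast_var_name_set above.
  std::vector<std::unordered_set<std::string>> bcast_var_name_set(4);
  int op_dev_id = 2;  // device the recv op was scheduled on

  // A whole received parameter is recorded for broadcasting.
  std::string recv_param = "fc_0.w_0";
  if (recv_param.find(".block") == std::string::npos) {
    bcast_var_name_set[op_dev_id].emplace(recv_param);
  }

  // A slice of a split parameter is skipped; the concat op registers the
  // original parameter name once the slices are stitched back together.
  std::string recv_slice = "embedding_0.w_0.block0";
  if (recv_slice.find(".block") == std::string::npos) {
    bcast_var_name_set[op_dev_id].emplace(recv_slice);  // not reached
  }
  std::string origin_param_name = "embedding_0.w_0";
  bcast_var_name_set[op_dev_id].emplace(origin_param_name);
  return 0;
}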
@@ -414,7 +431,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
             CreateReduceOp(&result, g_name, cur_device_id);
             graph->Get<ShardedVarDevice>(kShardedVarDevice)
                 .emplace(g_name, cur_device_id);
-            bcast_var_name_set[cur_device_id].emplace(p_name);
+            if (!is_dist_train) {
+              bcast_var_name_set[cur_device_id].emplace(p_name);
+            }
             break;
           case BuildStrategy::ReduceStrategy::kAllReduce:
             if (IsSparseGradient(g_name)) {
@@ -436,14 +455,19 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       }
     }
   }
 
   bool use_gpu = false;
 #ifdef PADDLE_WITH_CUDA
   use_gpu = nccl_ctxs_ != nullptr;
 #endif
 
-  if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
-    // Insert BCast Ops
+  // Insert broadcast operators principle:
+  // 1. Broadcast optimized parameters in Reduce strategy;
+  // 2. No need to broadcast optimized parameters in AllReduce strategy,
+  //    because the optimization sub-graph runs on every GPU;
+  // 3. Always broadcast received parameters in distributed training.
+  if ((use_gpu &&
+       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
+      is_dist_train) {
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
       for (auto &bcast_name : to_bcast_set) {
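The condition rewritten above implements the three-point broadcast principle stated in the new comment: broadcast after a Reduce-style optimization on GPU, skip it under AllReduce, and always broadcast parameters received during distributed training. A minimal standalone sketch of that decision follows (the helper name and boolean parameters are assumptions for illustration, not part of the patch).

#include <iostream>

// Mirrors: (use_gpu && strategy_.reduce_ == kReduce) || is_dist_train
static bool ShouldInsertBroadcastOps(bool use_gpu, bool use_reduce_strategy,
                                     bool is_dist_train) {
  return (use_gpu && use_reduce_strategy) || is_dist_train;
}

int main() {
  // AllReduce on GPU, local training: every device optimizes, no broadcast.
  std::cout << ShouldInsertBroadcastOps(true, false, false) << "\n";  // 0
  // Reduce on GPU, local training: broadcast the optimized parameters.
  std::cout << ShouldInsertBroadcastOps(true, true, false) << "\n";   // 1
  // Distributed training, even on CPU: broadcast the received parameters.
  std::cout << ShouldInsertBroadcastOps(false, false, true) << "\n";  // 1
  return 0;
}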
@@ -675,8 +699,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
   return var;
 }
 
-void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
-                                                ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
+                                               ir::Node *node) const {
   int op_dev_id = -1;
   std::vector<std::string> input_var_names;
   std::vector<std::string> output_var_names;
@@ -719,6 +743,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                  node->Op()->Type());
 
   CreateComputationalOp(result, node, op_dev_id);
+  return op_dev_id;
 }
 
 void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
@@ -737,8 +762,8 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
 }
 
 // Create RPC related op handles that connects its in ops and out ops.
-void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
-                                          ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
+                                         ir::Node *node) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -824,6 +849,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
       CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
     }
   }
+  return op_dev_id;
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {