|
|
|
@ -201,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
|
|
|
|
|
auto &g_name = backward_vars[i + 1];
|
|
|
|
|
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
|
|
|
|
|
|
|
|
|
|
InsertCollectiveOp(&result, node, p_name, g_name);
|
|
|
|
|
InsertCollectiveOp(&result, p_name, g_name);
|
|
|
|
|
}
|
|
|
|
|
} catch (boost::bad_get e) {
|
|
|
|
|
}
|
|
|
|
@ -386,7 +386,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
|
|
|
|
|
ir::Graph *result, ir::Node *node, const std::string &og) const {
|
|
|
|
|
ir::Graph *result, const std::string &og) const {
|
|
|
|
|
OpHandleBase *op_handle = nullptr;
|
|
|
|
|
|
|
|
|
|
auto append_allreduce_op = [&](
|
|
|
|
@ -510,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AllReduceSSAGraphBuilder::InsertCollectiveOp(
|
|
|
|
|
ir::Graph *result, ir::Node *node, const std::string &p_name,
|
|
|
|
|
ir::Graph *result, const std::string &p_name,
|
|
|
|
|
const std::string &g_name) const {
|
|
|
|
|
if (IsSparseGradient(g_name)) {
|
|
|
|
|
CreateReduceOp(result, g_name, 0);
|
|
|
|
|
CreateBroadcastOp(result, g_name, 0);
|
|
|
|
|
} else {
|
|
|
|
|
CreateAllReduceOp(result, node, g_name);
|
|
|
|
|
CreateAllReduceOp(result, g_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -589,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReduceSSAGraphBuilder::InsertCollectiveOp(
|
|
|
|
|
ir::Graph *result, ir::Node *node, const std::string &p_name,
|
|
|
|
|
ir::Graph *result, const std::string &p_name,
|
|
|
|
|
const std::string &g_name) const {
|
|
|
|
|
size_t cur_device_id = GetAppropriateDeviceID({g_name});
|
|
|
|
|
CreateReduceOp(result, g_name, cur_device_id);
|
|
|
|
@ -909,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
return op_dev_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
|
|
|
|
|
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
const std::string &g_name) const {
|
|
|
|
|
size_t cur_device_id = 0;
|
|
|
|
@ -924,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
|
|
|
|
|
CreateReduceOp(result, g_name, 0);
|
|
|
|
|
CreateBroadcastOp(result, g_name, 0);
|
|
|
|
|
} else {
|
|
|
|
|
CreateAllReduceOp(result, node, g_name);
|
|
|
|
|
CreateAllReduceOp(result, g_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
@ -975,8 +975,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
|
|
|
|
|
.RequirePassAttr(paddle::framework::details::kPlaces) \
|
|
|
|
|
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
|
|
|
|
|
.RequirePassAttr(paddle::framework::details::kStrategy) \
|
|
|
|
|
.RequirePassAttr(paddle::framework::details::kNRanks) \
|
|
|
|
|
.RequirePassAttr(paddle::framework::details::kEnablePG)
|
|
|
|
|
.RequirePassAttr(paddle::framework::details::kNRanks)
|
|
|
|
|
|
|
|
|
|
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
|
|
|
|
|
paddle::framework::details::ReduceSSAGraphBuilder);
|
|
|
|
|