|
|
|
@ -163,27 +163,34 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
if (static_cast<bool>(boost::get<int>(op->GetAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleAttrName())) &
|
|
|
|
|
static_cast<int>(OpRole::kBackward))) {
|
|
|
|
|
auto backward_vars = boost::get<std::vector<std::string>>(
|
|
|
|
|
op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(),
|
|
|
|
|
std::vector<std::string>()));
|
|
|
|
|
for (auto &og : backward_vars) {
|
|
|
|
|
switch (strategy_.reduce_) {
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|
CreateReduceOp(&result, og, cur_device_id);
|
|
|
|
|
var_name_on_devices[cur_device_id].emplace(og);
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(
|
|
|
|
|
og.substr(0, og.size() - strlen(kGradVarSuffix)));
|
|
|
|
|
cur_device_id = (cur_device_id + 1) % places_.size();
|
|
|
|
|
break;
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kAllReduce:
|
|
|
|
|
if (IsSparseGradient(var_types, og)) {
|
|
|
|
|
CreateReduceOp(&result, og, 0);
|
|
|
|
|
CreateBroadcastOp(&result, og, 0);
|
|
|
|
|
} else {
|
|
|
|
|
InsertNCCLAllReduceOp(&result, og);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
try {
|
|
|
|
|
auto backward_vars =
|
|
|
|
|
boost::get<std::vector<std::string>>(op->GetNullableAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < backward_vars.size(); ++i) {
|
|
|
|
|
auto &p_name = backward_vars[i];
|
|
|
|
|
auto &g_name = backward_vars[i + 1];
|
|
|
|
|
switch (strategy_.reduce_) {
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|
CreateReduceOp(&result, g_name, cur_device_id);
|
|
|
|
|
var_name_on_devices[cur_device_id].emplace(g_name);
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(p_name);
|
|
|
|
|
cur_device_id = (cur_device_id + 1) % places_.size();
|
|
|
|
|
break;
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kAllReduce:
|
|
|
|
|
if (IsSparseGradient(var_types, g_name)) {
|
|
|
|
|
CreateReduceOp(&result, g_name, 0);
|
|
|
|
|
CreateBroadcastOp(&result, g_name, 0);
|
|
|
|
|
} else {
|
|
|
|
|
InsertNCCLAllReduceOp(&result, g_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} catch (boost::bad_get e) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|