|
|
|
@ -207,53 +207,56 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
is_forwarding = false;
|
|
|
|
|
} else {
|
|
|
|
|
int op_dev_id = GetOpDeviceID(*op);
|
|
|
|
|
if (op_dev_id == -1) { // var on all device
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
} else {
|
|
|
|
|
if (op_dev_id != -1) { // This op only runs on one specific device.
|
|
|
|
|
CreateComputationalOp(&result, *op, op_dev_id);
|
|
|
|
|
for (auto &var_name : op->OutputArgumentNames()) {
|
|
|
|
|
var_name_on_devices_.emplace(var_name, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!is_forwarding && places_.size() > 1) {
|
|
|
|
|
// Currently, we assume that once a gradient is generated, it can be
|
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
|
if (static_cast<bool>(boost::get<int>(op->GetAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleAttrName())) &
|
|
|
|
|
static_cast<int>(OpRole::kBackward))) {
|
|
|
|
|
try {
|
|
|
|
|
auto backward_vars =
|
|
|
|
|
boost::get<std::vector<std::string>>(op->GetNullableAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < backward_vars.size(); i += 2) {
|
|
|
|
|
auto &p_name = backward_vars[i];
|
|
|
|
|
auto &g_name = backward_vars[i + 1];
|
|
|
|
|
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
|
|
|
|
|
|
|
|
|
|
switch (strategy_.reduce_) {
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|
cur_device_id = GetAppropriateDeviceID({g_name});
|
|
|
|
|
CreateReduceOp(&result, g_name, cur_device_id);
|
|
|
|
|
var_name_on_devices_.emplace(g_name, cur_device_id);
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(p_name);
|
|
|
|
|
break;
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kAllReduce:
|
|
|
|
|
if (IsSparseGradient(g_name)) {
|
|
|
|
|
CreateReduceOp(&result, g_name, 0);
|
|
|
|
|
CreateBroadcastOp(&result, g_name, 0);
|
|
|
|
|
} else {
|
|
|
|
|
InsertAllReduceOp(&result, g_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Unknown reduce strategy ";
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
// This op runs on all devices, and its outputs may include parameters'
|
|
|
|
|
// gradients.
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
|
|
|
|
|
if (!is_forwarding && places_.size() > 1) {
|
|
|
|
|
// Currently, we assume that once a gradient is generated, it can be
|
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
|
if (static_cast<bool>(boost::get<int>(op->GetAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleAttrName())) &
|
|
|
|
|
static_cast<int>(OpRole::kBackward))) {
|
|
|
|
|
try {
|
|
|
|
|
auto backward_vars =
|
|
|
|
|
boost::get<std::vector<std::string>>(op->GetNullableAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < backward_vars.size(); i += 2) {
|
|
|
|
|
auto &p_name = backward_vars[i];
|
|
|
|
|
auto &g_name = backward_vars[i + 1];
|
|
|
|
|
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
|
|
|
|
|
|
|
|
|
|
switch (strategy_.reduce_) {
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|
cur_device_id = GetAppropriateDeviceID({g_name});
|
|
|
|
|
CreateReduceOp(&result, g_name, cur_device_id);
|
|
|
|
|
var_name_on_devices_.emplace(g_name, cur_device_id);
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(p_name);
|
|
|
|
|
break;
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kAllReduce:
|
|
|
|
|
if (IsSparseGradient(g_name)) {
|
|
|
|
|
CreateReduceOp(&result, g_name, 0);
|
|
|
|
|
CreateBroadcastOp(&result, g_name, 0);
|
|
|
|
|
} else {
|
|
|
|
|
InsertAllReduceOp(&result, g_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Unknown reduce strategy ";
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} catch (boost::bad_get e) {
|
|
|
|
|
}
|
|
|
|
|
} catch (boost::bad_get e) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|