@@ -142,7 +142,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
 
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
-  VLOG(3) << "Building ....";
   std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
     all_vars[var->Name()] = var;
@@ -162,36 +161,32 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   auto send_vars = FindDistTrainSendVars(program);
   auto recv_vars = FindDistTrainRecvVars(program);
 
-  std::vector<std::unordered_set<std::string>> var_name_on_devices;
   std::vector<std::unordered_set<std::string>> bcast_var_name_set;
-  var_name_on_devices.resize(places_.size());
   bcast_var_name_set.resize(places_.size());
 
   size_t cur_device_id = 0;
   std::vector<int64_t> balance_grads(places_.size(), 0);
 
-  auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
-    auto var_desc = all_vars.at(g_name);
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
-    auto dim = framework::make_ddim(var_desc->GetShape());
-    int64_t numel = framework::product(dim);
-    PADDLE_ENFORCE_GE(numel, 0);
+  auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
+    int64_t numel_all = 0;
+    for (auto var_name : var_names) {
+      auto var_desc = all_vars.at(var_name);
+      PADDLE_ENFORCE_NOT_NULL(var_desc);
+      auto dim = framework::make_ddim(var_desc->GetShape());
+      int64_t numel = framework::product(dim);
+      PADDLE_ENFORCE_GT(numel, 0);
+      numel_all += numel;
+    }
+
     auto smallest =
         std::min_element(std::begin(balance_grads), std::end(balance_grads));
     size_t dev_id =
         static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel;
+    balance_grads[dev_id] += numel_all;
    return dev_id;
   };
 
   bool is_forwarding = true;
-  int rpc_op_device_id = 0;
-  auto schedule_rpc_op = [&]() -> void {
-    rpc_op_device_id++;
-    if (rpc_op_device_id >= static_cast<int>(places_.size())) {
-      rpc_op_device_id = 0;
-    }
-  };
 
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
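
The reworked get_appropriate_dev above generalizes the gradient balancer from a single variable to a group of variables, so all outputs of one RPC or split op land on the same device. The policy is greedy least-loaded placement: pick the device with the smallest accumulated element count, then charge the whole group's size to it. A minimal standalone sketch of that policy follows; PickDevice and its signature are illustrative, not Paddle APIs.

    #include <algorithm>
    #include <cstdint>
    #include <iterator>
    #include <vector>

    // Greedy least-loaded placement: the device with the smallest accumulated
    // element count wins, and the group's total size is charged to it.
    size_t PickDevice(std::vector<int64_t> *load,
                      const std::vector<int64_t> &group_numels) {
      int64_t total = 0;
      for (int64_t n : group_numels) total += n;  // the group is placed as a unit
      auto smallest = std::min_element(load->begin(), load->end());
      size_t dev_id =
          static_cast<size_t>(std::distance(load->begin(), smallest));
      (*load)[dev_id] += total;
      return dev_id;
    }
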
@@ -200,37 +195,40 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       if (op->Type() == "send_vars") {
-        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
-        if (got == remote_vars_devices_.end()) {
-          schedule_rpc_op();
-        } else {
-          rpc_op_device_id = got->second;
+        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
+        if (op_dev_id == -1) {
+          op_dev_id = get_appropriate_dev(op->InputArgumentNames());
+          for (auto &varname : op->InputArgumentNames()) {
+            var_name_on_devices_.emplace(varname, op_dev_id);
+          }
         }
-        CreateRPCOp(&result, *op, rpc_op_device_id);
+        CreateRPCOp(&result, *op, op_dev_id);
       } else if (op->Type() == "recv") {
-        schedule_rpc_op();
+        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
         for (auto &varname : op->OutputArgumentNames()) {
-          remote_vars_devices_.insert({varname, rpc_op_device_id});
+          var_name_on_devices_.emplace(varname, op_dev_id);
         }
-        CreateRPCOp(&result, *op, rpc_op_device_id);
+        CreateRPCOp(&result, *op, op_dev_id);
       } else {
+        // send_barrier and fetch_barrier op would run on device 0
         CreateRPCOp(&result, *op, 0);
       }
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
       if (op->Type() == "split_byref") {
-        schedule_rpc_op();
+        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
         for (auto &varname : op->OutputArgumentNames()) {
-          remote_vars_devices_.insert({varname, rpc_op_device_id});
+          var_name_on_devices_.emplace(varname, op_dev_id);
         }
-        CreateDistTrainOp(&result, *op, rpc_op_device_id);
-      }
-      if (op->Type() == "concat") {
-        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
-        PADDLE_ENFORCE(got != remote_vars_devices_.end(),
+        CreateDistTrainOp(&result, *op, op_dev_id);
+      } else if (op->Type() == "concat") {
+        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
+        PADDLE_ENFORCE(op_dev_id != -1,
                        "can not find right place to concatenate received var.");
-        CreateDistTrainOp(&result, *op, got->second);
+        CreateDistTrainOp(&result, *op, op_dev_id);
       } else {
-        CreateDistTrainOp(&result, *op, 0);
+        PADDLE_ENFORCE(
+            "the distribute training related op should be in [split_byref, "
+            "concat].");
       }
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
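
The placement logic above is sticky: send_vars first asks GetVarDeviceID whether its inputs were already pinned (for example by an earlier split_byref) and only falls back to the balancer, recording the choice for every input name; recv and split_byref always pick via the balancer and record their outputs so downstream ops follow. A small sketch of that reuse-or-assign pattern follows; PlaceGroup and the bare map type are illustrative, not the builder's actual members.

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Reuse a group's recorded device if its first name is already pinned;
    // otherwise pin every name in the (non-empty) group to the candidate.
    int PlaceGroup(std::unordered_map<std::string, int> *var_dev,
                   const std::vector<std::string> &names, int candidate) {
      auto got = var_dev->find(names.front());
      if (got != var_dev->end()) return got->second;  // already placed
      for (const auto &name : names) var_dev->emplace(name, candidate);
      return candidate;
    }
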
@@ -240,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
+      int op_dev_id = GetOpDeviceID(*op);
       if (op_dev_id == -1) {  // var on all device
         CreateComputationalOps(&result, *op, places_.size());
       } else {
         CreateComputationalOp(&result, *op, op_dev_id);
         for (auto &var_name : op->OutputArgumentNames()) {
-          var_name_on_devices[op_dev_id].emplace(var_name);
+          var_name_on_devices_.emplace(var_name, op_dev_id);
         }
       }
       if (!is_forwarding && places_.size() > 1) {
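
In the branch above, -1 is the sentinel for "inputs are replicated on every device", in which case the op itself is replicated; any other id pins the op to one place. A compact model of that dispatch follows; DispatchOp and its callback are hypothetical stand-ins for CreateComputationalOp/CreateComputationalOps.

    #include <cstddef>
    #include <functional>

    // -1: inputs live everywhere, so create one op instance per place;
    // otherwise create a single instance on the pinned device.
    void DispatchOp(int op_dev_id, size_t num_places,
                    const std::function<void(size_t)> &create_on) {
      if (op_dev_id == -1) {
        for (size_t i = 0; i < num_places; ++i) create_on(i);
      } else {
        create_on(static_cast<size_t>(op_dev_id));
      }
    }
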
@@ -269,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
         switch (strategy_.reduce_) {
           case BuildStrategy::ReduceStrategy::kReduce:
-            cur_device_id = get_appropriate_dev(g_name);
+            cur_device_id = get_appropriate_dev({g_name});
             CreateReduceOp(&result, g_name, cur_device_id);
-            var_name_on_devices[cur_device_id].emplace(g_name);
+            var_name_on_devices_.emplace(g_name, cur_device_id);
             bcast_var_name_set[cur_device_id].emplace(p_name);
             break;
           case BuildStrategy::ReduceStrategy::kAllReduce:
@@ -402,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-    const OpDesc &op) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
 
-  int var_dev_id = -1;
-  for (auto &var_name : op.InputArgumentNames()) {
-    if (var_dev_id != -1) break;
-    for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
-      if (var_name_on_devices[i].count(var_name)) {
-        var_dev_id = static_cast<int>(i);
-        break;
-      }
+  for (auto &varname : op.InputArgumentNames()) {
+    int dev_id = GetVarDeviceID(varname);
+    if (dev_id != -1) {
+      return dev_id;
     }
   }
-  return var_dev_id;
+  return -1;
+}
+
+int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
+  auto got = var_name_on_devices_.find(varname);
+  return got == var_name_on_devices_.end() ? -1 : got->second;
 }
 
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
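
Taken together, the new GetVarDeviceID/GetOpDeviceID pair replaces the per-device name sets with a single map lookup: a variable knows its device, and under the kReduce strategy an op inherits the device of its first pinned input (otherwise the query returns -1 and the op is replicated). A self-contained model of the two queries follows; the Placement struct is illustrative and omits the strategy check, it is not the builder's interface.

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Placement {
      std::unordered_map<std::string, int> var_dev;  // var name -> device id

      // Device a variable was pinned to, or -1 if unpinned.
      int VarDevice(const std::string &name) const {
        auto got = var_dev.find(name);
        return got == var_dev.end() ? -1 : got->second;
      }

      // First pinned input decides the op's device; -1 means replicate.
      int OpDevice(const std::vector<std::string> &inputs) const {
        for (const auto &in : inputs) {
          int dev = VarDevice(in);
          if (dev != -1) return dev;
        }
        return -1;
      }
    };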