|
|
|
@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
for (auto &p : params) {
|
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
|
|
}
|
|
|
|
|
balance_vars_.resize(places_.size(), 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
|
|
|
|
@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
|
|
|
|
|
checker(op.InputArgumentNames(), recv_vars);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
|
|
|
|
|
const std::vector<std::string> &var_names) const {
|
|
|
|
|
int64_t numel_sum = 0;
|
|
|
|
|
for (auto var_name : var_names) {
|
|
|
|
|
auto var_desc = all_vars_.at(var_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var_desc);
|
|
|
|
|
auto dim = framework::make_ddim(var_desc->GetShape());
|
|
|
|
|
int64_t numel = framework::product(dim);
|
|
|
|
|
PADDLE_ENFORCE_GT(numel, 0);
|
|
|
|
|
numel_sum += numel;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto smallest =
|
|
|
|
|
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
|
|
|
|
|
size_t dev_id =
|
|
|
|
|
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
|
|
|
|
|
balance_vars_[dev_id] += numel_sum;
|
|
|
|
|
return dev_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
const ProgramDesc &program) const {
|
|
|
|
|
std::unordered_map<std::string, VarDesc *> all_vars;
|
|
|
|
|
for (auto *var : program.Block(0).AllVars()) {
|
|
|
|
|
all_vars[var->Name()] = var;
|
|
|
|
|
all_vars_.emplace(var->Name(), var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto graph = new SSAGraph();
|
|
|
|
@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
auto send_vars = FindDistTrainSendVars(program);
|
|
|
|
|
auto recv_vars = FindDistTrainRecvVars(program);
|
|
|
|
|
|
|
|
|
|
std::vector<std::unordered_set<std::string>> var_name_on_devices;
|
|
|
|
|
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
|
|
|
|
|
var_name_on_devices.resize(places_.size());
|
|
|
|
|
bcast_var_name_set.resize(places_.size());
|
|
|
|
|
|
|
|
|
|
size_t cur_device_id = 0;
|
|
|
|
|
std::vector<int64_t> balance_grads(places_.size(), 0);
|
|
|
|
|
|
|
|
|
|
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
|
|
|
|
|
auto var_desc = all_vars.at(g_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var_desc);
|
|
|
|
|
auto dim = framework::make_ddim(var_desc->GetShape());
|
|
|
|
|
int64_t numel = framework::product(dim);
|
|
|
|
|
PADDLE_ENFORCE_GE(numel, 0);
|
|
|
|
|
auto smallest =
|
|
|
|
|
std::min_element(std::begin(balance_grads), std::end(balance_grads));
|
|
|
|
|
size_t dev_id =
|
|
|
|
|
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
|
|
|
|
|
balance_grads[dev_id] += numel;
|
|
|
|
|
return dev_id;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
bool is_forwarding = true;
|
|
|
|
|
|
|
|
|
|
for (auto *op : program.Block(0).AllOps()) {
|
|
|
|
|
if (boost::get<int>(
|
|
|
|
|
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
|
|
|
|
|
static_cast<int>(OpRole::kRPC)) {
|
|
|
|
|
// append rpc op if program is distributed trainer main program.
|
|
|
|
|
// always use the first device
|
|
|
|
|
CreateRPCOp(&result, *op);
|
|
|
|
|
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
|
|
|
|
|
CreateDistTrainOp(&result, *op);
|
|
|
|
@ -201,13 +202,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
}
|
|
|
|
|
is_forwarding = false;
|
|
|
|
|
} else {
|
|
|
|
|
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
|
|
|
|
|
int op_dev_id = GetOpDeviceID(*op);
|
|
|
|
|
if (op_dev_id == -1) { // var on all device
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
} else {
|
|
|
|
|
CreateComputationalOp(&result, *op, op_dev_id);
|
|
|
|
|
for (auto &var_name : op->OutputArgumentNames()) {
|
|
|
|
|
var_name_on_devices[op_dev_id].emplace(var_name);
|
|
|
|
|
var_name_on_devices_.emplace(var_name, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!is_forwarding && places_.size() > 1) {
|
|
|
|
@ -230,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
|
|
|
|
|
switch (strategy_.reduce_) {
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|
cur_device_id = get_appropriate_dev(g_name);
|
|
|
|
|
cur_device_id = GetAppropriateDeviceID({g_name});
|
|
|
|
|
CreateReduceOp(&result, g_name, cur_device_id);
|
|
|
|
|
var_name_on_devices[cur_device_id].emplace(g_name);
|
|
|
|
|
var_name_on_devices_.emplace(g_name, cur_device_id);
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(p_name);
|
|
|
|
|
break;
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kAllReduce:
|
|
|
|
|
if (IsSparseGradient(all_vars, g_name)) {
|
|
|
|
|
if (IsSparseGradient(g_name)) {
|
|
|
|
|
CreateReduceOp(&result, g_name, 0);
|
|
|
|
|
CreateBroadcastOp(&result, g_name, 0);
|
|
|
|
|
} else {
|
|
|
|
@ -273,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
return std::unique_ptr<SSAGraph>(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsSparseGradient(
|
|
|
|
|
const std::unordered_map<std::string, VarDesc *> &all_vars,
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
PADDLE_ENFORCE(all_vars.count(og) != 0);
|
|
|
|
|
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
|
|
|
|
|
PADDLE_ENFORCE(all_vars_.count(og) != 0);
|
|
|
|
|
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
|
|
|
|
|
return is_pg_once;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetOpDeviceID(
|
|
|
|
|
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
|
|
|
|
|
const OpDesc &op) const {
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
|
|
|
|
|
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int var_dev_id = -1;
|
|
|
|
|
for (auto &var_name : op.InputArgumentNames()) {
|
|
|
|
|
if (var_dev_id != -1) break;
|
|
|
|
|
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
|
|
|
|
|
if (var_name_on_devices[i].count(var_name)) {
|
|
|
|
|
var_dev_id = static_cast<int>(i);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
for (auto &varname : op.InputArgumentNames()) {
|
|
|
|
|
int dev_id = GetVarDeviceID(varname);
|
|
|
|
|
if (dev_id != -1) {
|
|
|
|
|
return dev_id;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return var_dev_id;
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
|
|
|
|
|
auto got = var_name_on_devices_.find(varname);
|
|
|
|
|
return got == var_name_on_devices_.end() ? -1 : got->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
|
|
|
|
@ -463,7 +461,30 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
|
|
|
|
|
const OpDesc &op) const {
|
|
|
|
|
CreateComputationalOp(result, op, 0);
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
if (op.Type() == "split_byref") {
|
|
|
|
|
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
|
|
|
|
|
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
|
|
|
|
|
for (auto &varname : op.InputArgumentNames()) {
|
|
|
|
|
var_name_on_devices_.emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto &varname : op.OutputArgumentNames()) {
|
|
|
|
|
var_name_on_devices_.emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else if (op.Type() == "concat") {
|
|
|
|
|
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
"the distribute training related op should be in [split_byref, "
|
|
|
|
|
"concat].");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1,
|
|
|
|
|
"can not find right place for distributed op: %s", op.Type());
|
|
|
|
|
|
|
|
|
|
CreateComputationalOp(result, op, op_dev_id);
|
|
|
|
|
if (op.Type() == "concat") {
|
|
|
|
|
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
|
|
|
|
|
}
|
|
|
|
@ -471,8 +492,34 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
|
|
|
|
|
const OpDesc &op) const {
|
|
|
|
|
result->ops_.emplace_back(
|
|
|
|
|
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
if (op.Type() == "send") {
|
|
|
|
|
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
|
|
|
|
|
// the variable name which contains .block means it was split by
|
|
|
|
|
// split_byref op
|
|
|
|
|
// so that we can balance the variable blocks to all the pserver instances.
|
|
|
|
|
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
|
|
|
|
|
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
|
|
|
|
|
for (auto &varname : op.InputArgumentNames()) {
|
|
|
|
|
var_name_on_devices_.emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (op.Type() == "recv") {
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
|
|
|
|
|
for (auto &varname : op.OutputArgumentNames()) {
|
|
|
|
|
var_name_on_devices_.emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// send_barrier and fetch_barrier op can be scheduled on device 0
|
|
|
|
|
op_dev_id = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
|
|
|
|
|
op.Type());
|
|
|
|
|
|
|
|
|
|
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
|
|
|
|
|
op.Type(), places_[op_dev_id]));
|
|
|
|
|
|
|
|
|
|
if (op.Type() == "send_barrier") {
|
|
|
|
|
ConnectOp(result, result->ops_.back().get(), "send");
|
|
|
|
@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
|
|
|
|
|
"send, send_barrier. recv, fetch_barrier]");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(Yancey1989): schedule rpc op on different place may
|
|
|
|
|
// increase throughput
|
|
|
|
|
CreateOpHandleIOs(result, op, 0);
|
|
|
|
|
CreateOpHandleIOs(result, op, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
|
|
|
|
|