|
|
|
@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
const ProgramDesc &program) const {
|
|
|
|
|
std::unordered_map<std::string, proto::VarType::Type> var_types;
|
|
|
|
|
for (auto *var : program.Block(0).AllVars()) {
|
|
|
|
|
var_types[var->Name()] = var->GetType();
|
|
|
|
|
}
|
|
|
|
|
auto graph = new SSAGraph();
|
|
|
|
|
SSAGraph &result = *graph;
|
|
|
|
|
std::unordered_set<std::string> og_has_been_broadcast;
|
|
|
|
@ -116,7 +120,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
|
|
|
|
|
places_.size());
|
|
|
|
|
|
|
|
|
|
size_t cur_update_sparse_gp_dev_id = 0;
|
|
|
|
|
size_t cur_dev_id = 0;
|
|
|
|
|
std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
|
|
|
|
|
std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
|
|
|
|
|
|
|
|
|
@ -156,14 +160,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
|
for (auto &og : op->OutputArgumentNames()) {
|
|
|
|
|
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
|
|
|
|
|
if (IsSparseGradient(og)) {
|
|
|
|
|
CreateReduceOp(&result, cur_update_sparse_gp_dev_id, og);
|
|
|
|
|
sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace(
|
|
|
|
|
og);
|
|
|
|
|
bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace(
|
|
|
|
|
if (IsSparseGradient(var_types, og)) {
|
|
|
|
|
CreateReduceOp(&result, cur_dev_id, og);
|
|
|
|
|
sparse_var_name_on_devices[cur_dev_id].emplace(og);
|
|
|
|
|
bcast_sparse_var_name_set[cur_dev_id].emplace(
|
|
|
|
|
og.substr(0, og.size() - strlen(kGradVarSuffix)));
|
|
|
|
|
cur_update_sparse_gp_dev_id =
|
|
|
|
|
(cur_update_sparse_gp_dev_id + 1) % places_.size();
|
|
|
|
|
cur_dev_id = (cur_dev_id + 1) % places_.size();
|
|
|
|
|
} else {
|
|
|
|
|
InsertNCCLAllReduceOp(&result, og);
|
|
|
|
|
}
|
|
|
|
@ -201,10 +203,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
return std::unique_ptr<SSAGraph>(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
|
|
|
|
|
auto og_var = local_scopes_[0]->FindVar(og);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(og_var);
|
|
|
|
|
return og_var->IsType<SelectedRows>();
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsSparseGradient(
|
|
|
|
|
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
PADDLE_ENFORCE(var_types.count(og) != 0);
|
|
|
|
|
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetOpDeviceID(
|
|
|
|
|