@@ -431,7 +431,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
             CreateReduceOp(&result, g_name, cur_device_id);
             graph->Get<ShardedVarDevice>(kShardedVarDevice)
                 .emplace(g_name, cur_device_id);
-            bcast_var_name_set[cur_device_id].emplace(p_name);
+            if (!is_dist_train) {
+              bcast_var_name_set[cur_device_id].emplace(p_name);
+            }
             break;
           case BuildStrategy::ReduceStrategy::kAllReduce:
             if (IsSparseGradient(g_name)) {
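For readers skimming the hunk above: under `ReduceStrategy::kReduce`, each gradient is reduced onto one device and the matching parameter is queued so it can later be broadcast back to the other devices; the change skips that queueing when `is_dist_train` is set, because in distributed training the updated parameters come back from the parameter server instead. Below is a minimal self-contained sketch of that bookkeeping; the round-robin placement and the simplified names (`params_grads`, `sharded_var_device`) are assumptions for illustration, not the pass's real state.

```cpp
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
  const size_t num_devices = 2;
  const bool is_dist_train = true;  // toggle to compare both behaviors

  // (parameter, gradient) pairs as the backward pass would produce them.
  std::vector<std::pair<std::string, std::string>> params_grads = {
      {"fc_0.w", "fc_0.w@GRAD"}, {"fc_0.b", "fc_0.b@GRAD"}};

  // Stand-ins for the pass's state: gradient-to-device placement
  // (kShardedVarDevice) and the per-device broadcast queue.
  std::map<std::string, size_t> sharded_var_device;
  std::vector<std::set<std::string>> bcast_var_name_set(num_devices);

  size_t next_device = 0;
  for (const auto &pg : params_grads) {
    // Mimic GetAppropriateDeviceID with simple round-robin placement.
    size_t cur_device_id = next_device++ % num_devices;
    // CreateReduceOp(&result, g_name, cur_device_id) would go here.
    sharded_var_device.emplace(pg.second, cur_device_id);
    // In distributed training the updated parameter is received from the
    // parameter server, so the reduce-side broadcast is skipped.
    if (!is_dist_train) {
      bcast_var_name_set[cur_device_id].emplace(pg.first);
    }
  }

  for (size_t d = 0; d < num_devices; ++d) {
    std::cout << "device " << d << " broadcasts "
              << bcast_var_name_set[d].size() << " parameter(s)\n";
  }
  return 0;
}
```

Keeping the queue per-device means the later broadcast pass can root each broadcast at the device that already owns the reduced value, avoiding an extra copy.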
@@ -461,7 +463,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   if ((use_gpu &&
        strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
       is_dist_train) {
-    // Insert BCast Ops
+    // always broadcast received parameters for distributed training
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
       for (auto &bcast_name : to_bcast_set) {
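The second hunk is cut off inside the broadcast-insertion loop; in the surrounding pass the inner body presumably creates one broadcast op per queued name, rooted at the device holding the reduced (or received) copy. The sketch below shows only that shape; `InsertBcastOps` and the stubbed `CreateBroadcastOp` signature are hypothetical stand-ins, not the builder's actual API.

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Illustrative stub; the real CreateBroadcastOp also takes the graph being
// built (a &result parameter) and appends a broadcast op handle to it.
static void CreateBroadcastOp(const std::string &var, size_t src_dev) {
  std::cout << "broadcast " << var << " from device " << src_dev << "\n";
}

// Hypothetical helper completing the truncated loop: every queued variable
// is broadcast from the device that owns its copy.
static void InsertBcastOps(
    const std::vector<std::set<std::string>> &bcast_var_name_set) {
  for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
    const auto &to_bcast_set = bcast_var_name_set[dev_id];
    for (const auto &bcast_name : to_bcast_set) {
      CreateBroadcastOp(bcast_name, dev_id);
    }
  }
}

int main() {
  std::vector<std::set<std::string>> bcast_var_name_set = {{"fc_0.w"},
                                                           {"fc_0.b"}};
  InsertBcastOps(bcast_var_name_set);
  return 0;
}
```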