@@ -460,10 +460,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
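// Note: nccl_ctxs_ appears to be created only when CUDA places are in
// use (see the surrounding #ifdef), so a non-null pointer signals GPU mode.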
use_gpu = nccl_ctxs_ != nullptr;
#endif
// Principles for inserting broadcast operators:
// 1. Broadcast optimized parameters in the Reduce strategy;
// 2. No need to broadcast optimized parameters in the AllReduce strategy,
//    because the optimization sub-graph is run on every GPU;
// 3. Always broadcast received parameters in distributed training.
if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) {
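// This branch covers principles 1 and 3 above; principle 2 needs no
// broadcast, so the AllReduce-on-GPU case falls through untouched.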
// always broadcast received parameters for distributed training
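// bcast_var_name_set is indexed by device id: entry i collects the names
// of the parameters that device i must broadcast to its peers.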
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {