|
|
|
@ -34,6 +34,8 @@ namespace details {
|
|
|
|
|
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
|
|
|
|
|
// The allreduce op order must be fixed when scheduling
|
|
|
|
|
// them in multiple threads or processes to avoid hang.
|
|
|
|
|
// NOTE: ParallelExecutor would execute this pass on each graph, so
|
|
|
|
|
// it does not need to be appended here.
|
|
|
|
|
return (!strategy.enable_sequential_execution_ &&
|
|
|
|
|
strategy.num_trainers_ > 1) &&
|
|
|
|
|
!strategy.enable_parallel_graph_;
|
|
|
|
@ -118,7 +120,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Verify that the graph is correct for the multi-device executor.
|
|
|
|
|
auto multi_devices_pass = AppendPass("multi_devices_check_pass");
|
|
|
|
|
AppendPass("multi_devices_check_pass");
|
|
|
|
|
|
|
|
|
|
if (SeqOnlyAllReduceOps(strategy)) {
|
|
|
|
|
AppendPass("all_reduce_deps_pass");
|
|
|
|
|