|
|
|
@ -133,15 +133,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
|
|
|
|
|
void AppendMultiDevPass(const BuildStrategy &strategy) {
|
|
|
|
|
ir::Pass *multi_devices_pass;
|
|
|
|
|
if (strategy_.is_distribution_) {
|
|
|
|
|
VLOG(3) << "dist train mode";
|
|
|
|
|
VLOG(3) << "multi device dist train mode";
|
|
|
|
|
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
|
|
|
|
|
} else {
|
|
|
|
|
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
|
|
|
|
|
VLOG(3) << "allreduce mode";
|
|
|
|
|
VLOG(3) << "multi device allreduce mode";
|
|
|
|
|
multi_devices_pass =
|
|
|
|
|
AppendPass("allreduce_mode_multi_devices_pass").get();
|
|
|
|
|
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
|
|
|
|
|
VLOG(3) << "reduce mode";
|
|
|
|
|
VLOG(3) << "multi device reduce mode";
|
|
|
|
|
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unknown reduce strategy.");
|
|
|
|
|