diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py
index 7400f45e05..f43fcb2d9f 100644
--- a/python/paddle/fluid/compiler.py
+++ b/python/paddle/fluid/compiler.py
@@ -356,6 +356,16 @@ class CompiledProgram(object):
         if self._build_strategy.sync_batch_norm:
             self._build_strategy.enable_sequential_execution = True
 
+        if self._program is not None and self._program._enable_dgc:
+            assert use_cuda, "DGC only used under cuda"
+            assert self._build_strategy.num_trainers * len(
+                places) > 1, "DGC is not useful for single card training"
+            assert self._build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "DGC \
+                only used for AllReduce BuildStrategy"
+
+            # DGC doesn't support fuse for now, close fuse.
+            self._build_strategy.fuse_all_reduce_ops = False
+
         self._persistable_vars = []
         for node in self._graph.nodes():
             if node.is_var() and node.var() is not None and node.var().persistable() and \
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 179bac78ff..58380cf8e1 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -175,15 +175,6 @@ class ParallelExecutor(object):
         ) if use_cuda else framework.cpu_places()
         self._scope = scope if scope is not None else executor.global_scope()
 
-        if main_program is not None and main_program._enable_dgc:
-            assert build_strategy.num_trainers > 1, "dgc is not useful when num_trainers <= 1"
-            assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "dgc \
-                only used for allreduce"
-
-            assert build_strategy.num_trainers * len(
-                self._places) > 1, "dgc is not useful for single card training"
-            assert use_cuda, "dgc only used under cuda"
-
         main_program = main_program if main_program is not None \
             else framework.default_main_program()
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 4288a6c52a..ac0713d65e 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -334,10 +334,6 @@ class TestDistRunnerBase(object):
             build_stra.num_trainers = 1
             build_stra.trainer_id = 0
 
-        if args.use_dgc:
-            # fuse_all_reduce_ops require that gradients should not be sparse types
-            build_stra.fuse_all_reduce_ops = False
-
         print_to_err(type(self).__name__, "begin to compile with data parallel")
         binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
             loss_name=avg_cost.name,
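
For context, the sketch below is not part of the diff; it illustrates the caller-side effect of relocating the checks. Once DGCMomentumOptimizer has processed a program (it sets program._enable_dgc internally), CompiledProgram itself validates the CUDA / multi-card / AllReduce requirements and disables fuse_all_reduce_ops, so callers such as test_dist_base.py no longer need to do it by hand. The toy network and the two-GPU place list are assumptions for illustration only; the assertions fire when the compiled program is first compiled for execution, not at construction time.

import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))
    # DGCMomentumOptimizer flags the program with _enable_dgc.
    opt = fluid.optimizer.DGCMomentumOptimizer(
        learning_rate=0.001, momentum=0.9, rampup_begin_step=0)
    opt.minimize(loss)

build_strategy = fluid.BuildStrategy()
# No manual build_strategy.fuse_all_reduce_ops = False is needed any more;
# CompiledProgram turns it off itself when it sees _enable_dgc, and its
# assertions reject CPU, single-card, or non-AllReduce configurations.
compiled = fluid.compiler.CompiledProgram(main_prog).with_data_parallel(
    loss_name=loss.name,
    build_strategy=build_strategy,
    places=fluid.cuda_places([0, 1]))  # assumes a machine with 2 GPUs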