|
|
@ -45,6 +45,7 @@ class Collective(Fleet):
|
|
|
|
|
|
|
|
|
|
|
|
self.startup_program = None
|
|
|
|
self.startup_program = None
|
|
|
|
self._origin_program = None
|
|
|
|
self._origin_program = None
|
|
|
|
|
|
|
|
self._transpiled_program = None
|
|
|
|
self.main_program = None
|
|
|
|
self.main_program = None
|
|
|
|
|
|
|
|
|
|
|
|
def init_worker(self):
|
|
|
|
def init_worker(self):
|
|
|
@ -352,7 +353,8 @@ class CollectiveOptimizer(DistributedOptimizer):
|
|
|
|
parameter_list=parameter_list,
|
|
|
|
parameter_list=parameter_list,
|
|
|
|
no_grad_set=no_grad_set)
|
|
|
|
no_grad_set=no_grad_set)
|
|
|
|
|
|
|
|
|
|
|
|
fleet._origin_program = main_program
|
|
|
|
fleet._origin_program = main_program.clone(for_test=False)
|
|
|
|
|
|
|
|
fleet._transpiled_program = main_program
|
|
|
|
fleet.main_program = self._try_to_compile(startup_program, main_program)
|
|
|
|
fleet.main_program = self._try_to_compile(startup_program, main_program)
|
|
|
|
|
|
|
|
|
|
|
|
return optimize_ops, param_grads
|
|
|
|
return optimize_ops, param_grads
|
|
|
|