|
|
@ -191,8 +191,10 @@ class CollectiveOptimizer(DistributedOptimizer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, optimizer, strategy=DistributedStrategy()):
|
|
|
|
def __init__(self, optimizer, strategy=DistributedStrategy()):
|
|
|
|
|
|
|
|
if strategy is None:
|
|
|
|
|
|
|
|
strategy = DistributedStrategy()
|
|
|
|
super(CollectiveOptimizer, self).__init__(optimizer, strategy)
|
|
|
|
super(CollectiveOptimizer, self).__init__(optimizer, strategy)
|
|
|
|
if strategy is not None and strategy.forward_recompute:
|
|
|
|
if strategy.forward_recompute:
|
|
|
|
self.forward_recompute = True
|
|
|
|
self.forward_recompute = True
|
|
|
|
self.recompute_checkpoints = strategy.recompute_checkpoints
|
|
|
|
self.recompute_checkpoints = strategy.recompute_checkpoints
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|