|
|
|
@ -118,6 +118,22 @@ class DistributedStrategy(object):
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
self.strategy = distributed_strategy_pb2.DistributedStrategy()
|
|
|
|
|
|
|
|
|
|
# Set the default values of the following flags to the ones set by users
|
|
|
|
|
key = 'FLAGS_cudnn_batchnorm_spatial_persistent'
|
|
|
|
|
if core.globals().is_public(key):
|
|
|
|
|
self.strategy.cudnn_batchnorm_spatial_persistent = bool(
|
|
|
|
|
core.globals()[key])
|
|
|
|
|
key = 'FLAGS_conv_workspace_size_limit'
|
|
|
|
|
if core.globals().is_public(key):
|
|
|
|
|
self.strategy.conv_workspace_size_limit = int(core.globals()[key])
|
|
|
|
|
key = 'FLAGS_cudnn_exhaustive_search'
|
|
|
|
|
if core.globals().is_public(key):
|
|
|
|
|
self.strategy.cudnn_exhaustive_search = bool(core.globals()[key])
|
|
|
|
|
key = 'FLAGS_sync_nccl_allreduce'
|
|
|
|
|
if core.globals().is_public(key):
|
|
|
|
|
self.strategy.sync_nccl_allreduce = bool(core.globals()[key])
|
|
|
|
|
|
|
|
|
|
self.__lock_attr = True
|
|
|
|
|
|
|
|
|
|
def __setattr__(self, key, value):
|
|
|
|
|